
yichenchenchen committed ea88892 (verified) · 1 parent: 352e41f

Upload 25 files

src/builder.py ADDED
@@ -0,0 +1,4 @@
1
+ from mmengine.registry import Registry
2
+ __all__ = ['BUILDER']
3
+
4
+ BUILDER = Registry('builder')
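Downstream configs refer to classes through this registry. A minimal usage sketch follows (assuming the module is importable as `src.builder`; `ToyTokenizer` is a made-up class, not part of this commit):

# Hedged sketch: how the BUILDER registry above is typically used.
from src.builder import BUILDER

@BUILDER.register_module()
class ToyTokenizer:
    def __init__(self, vocab_size=100):
        self.vocab_size = vocab_size

# mmengine resolves the 'type' key to the registered class and instantiates it.
tok = BUILDER.build(dict(type='ToyTokenizer', vocab_size=32))
assert tok.vocab_size == 32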
src/datasets/collate_functions.py ADDED
@@ -0,0 +1,90 @@
1
+ import torch
2
+ from xtuner.utils import IGNORE_INDEX
3
+ from typing import Dict, Sequence
4
+ from torch.nn.utils.rnn import pad_sequence
5
+ from functools import partial
6
+ from dataclasses import dataclass
7
+
8
+
9
+ def collate_func_gen(instances: Sequence[Dict],
10
+ pad_index: int = 151645):
11
+ pixel_values_src, pixel_values, input_ids, input_lengths = [], [], [], []
12
+
13
+ for example in instances:
14
+ # pull out any image tensors
15
+ if 'pixel_values_src' in example:
16
+ pixel_values_src.append(example.pop('pixel_values_src'))
17
+ if 'pixel_values' in example:
18
+ pixel_values.append(example.pop('pixel_values'))
19
+
20
+ input_lengths.append(len(example['input_ids']))
21
+ input_ids.append(example.pop('input_ids'))
22
+
23
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_index)
24
+ attention_mask = torch.zeros_like(input_ids).bool()
25
+ for i in range(len(input_ids)):
26
+ attention_mask[i, :input_lengths[i]] = True
27
+
28
+ data_dict = {
29
+ 'input_ids': input_ids,
30
+ 'attention_mask': attention_mask,
31
+ }
32
+
33
+ if pixel_values:
34
+ data_dict['pixel_values'] = torch.stack(pixel_values)
35
+ if pixel_values_src:
36
+ data_dict['pixel_values_src'] = torch.stack(pixel_values_src)
37
+
38
+ return {'data': data_dict, 'data_samples': None}
39
+
40
+
41
+ def collate_func_und(instances, pad_index=151645):
42
+ input_ids_list, labels_list, pixel_values_list = [], [], []
43
+
44
+ for sample in instances:
45
+ input_ids_list.append(torch.LongTensor(sample['input_ids']))
46
+ labels_list.append(torch.LongTensor(sample['labels']))
47
+
48
+ if 'pixel_values' in sample:
49
+ pixel_values_list.append(sample['pixel_values'])
50
+
51
+ ori_length = [len(input_ids_) for input_ids_ in input_ids_list]
52
+ # right padding
53
+ if len(instances) > 1:
54
+ input_ids = pad_sequence(
55
+ input_ids_list, batch_first=True, padding_value=pad_index)
56
+ labels = pad_sequence(
57
+ labels_list, batch_first=True, padding_value=IGNORE_INDEX)
58
+ else:
59
+ input_ids = torch.stack(input_ids_list)
60
+ labels = torch.stack(labels_list)
61
+
62
+ attention_mask = torch.zeros_like(input_ids).bool()
63
+ for i, length in enumerate(ori_length):
64
+ attention_mask[i, :length] = True # right padding
65
+
66
+ data_dict = {
67
+ 'input_ids': input_ids,
68
+ 'attention_mask': attention_mask,
69
+ 'labels': labels,
70
+ 'pixel_values': torch.stack(pixel_values_list) if len(pixel_values_list) > 0 else None
71
+ }
72
+
73
+ return {'data': data_dict, 'data_samples': None}
74
+
75
+
76
+ class CollateConcat(object):
77
+ def __init__(self, collate_fns, keys):
78
+ self.keys = keys
79
+ self.collate_fns = {}
80
+ for key, collate_fn in zip(keys, collate_fns):
81
+ func = collate_fn.pop('type')
82
+ self.collate_fns[key] = partial(func, **collate_fn)
83
+
84
+ def __call__(self, data_samples):
85
+ data_samples = [data_sample for data_sample in data_samples if len(data_sample) > 0]
86
+ data_dict = {}
87
+ key = data_samples[0]['type']
88
+ data_dict[key] = self.collate_fns[key](data_samples)['data']
89
+
90
+ return {'data': data_dict, 'data_samples': None}
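The collate functions above right-pad `input_ids` and stack any image tensors. A small usage sketch with dummy values (assuming this file is importable as `src.datasets.collate_functions`):

# Hedged sketch: what collate_func_gen returns for a toy batch.
import torch
from src.datasets.collate_functions import collate_func_gen

instances = [
    {'input_ids': torch.tensor([1, 2, 3]), 'pixel_values': torch.zeros(3, 32, 32)},
    {'input_ids': torch.tensor([4, 5]), 'pixel_values': torch.zeros(3, 32, 32)},
]
out = collate_func_gen(instances, pad_index=151645)
data = out['data']
print(data['input_ids'].shape)      # torch.Size([2, 3]); second row padded with 151645
print(data['attention_mask'][1])    # tensor([ True,  True, False])
print(data['pixel_values'].shape)   # torch.Size([2, 3, 32, 32])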
src/datasets/samplers/multi_source_sampler.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import itertools
3
+ from typing import Iterator, List, Optional, Sized, Union
4
+ import torch
5
+ from mmengine.dist import get_dist_info, sync_random_seed
6
+ from torch.utils.data import Sampler
7
+
8
+
9
+ class FixedBatchMultiSourceSampler(Sampler):
10
+ r"""Multi-Source Infinite Sampler.
11
+
12
+ According to the sampling ratio, sample data from different
13
+ datasets to form batches.
14
+
15
+ Args:
16
+ repeat (tuple): repeat factor
17
+ dataset (Sized): The dataset.
18
+ batch_size (int): Size of mini-batch.
19
+ shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
20
+ seed (int, optional): Random seed. If None, set a random seed.
21
+ Defaults to None.
22
+ """
23
+
24
+ def __init__(self,
25
+ repeat,
26
+ dataset: Sized,
27
+ batch_size: int,
28
+ shuffle: bool = True,
29
+ seed: Optional[int] = None) -> None:
30
+
31
+ assert hasattr(dataset, 'cumulative_sizes'),\
32
+ f'The dataset must be a ConcatDataset, but got {dataset}'
33
+ assert isinstance(batch_size, int) and batch_size > 0, \
34
+ 'batch_size must be a positive integer value, ' \
35
+ f'but got batch_size={batch_size}'
36
+ assert len(repeat) == len(dataset.cumulative_sizes), \
37
+ 'The length of repeat must be equal to ' \
38
+ f'the number of datasets, but got repeat={repeat}'
39
+
40
+ rank, world_size = get_dist_info()
41
+ self.rank = rank
42
+ self.world_size = world_size
43
+
44
+ self.dataset = dataset
45
+ self.repeat = repeat
46
+ self.cumulative_sizes = [0] + dataset.cumulative_sizes
47
+ self.batch_size = batch_size
48
+
49
+ self.seed = sync_random_seed() if seed is None else seed
50
+ self.shuffle = shuffle
51
+ self.source2inds = {
52
+ source: self._indices_of_rank(len(ds))
53
+ for source, ds in enumerate(dataset.datasets)
54
+ }
55
+
56
+ def _infinite_indices(self, sample_size: int) -> Iterator[int]:
57
+ """Infinitely yield a sequence of indices."""
58
+ g = torch.Generator()
59
+ g.manual_seed(self.seed)
60
+ while True:
61
+ if self.shuffle:
62
+ yield from torch.randperm(sample_size, generator=g).tolist()
63
+ else:
64
+ yield from torch.arange(sample_size).tolist()
65
+
66
+ def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
67
+ """Slice the infinite indices by rank."""
68
+ yield from itertools.islice(
69
+ self._infinite_indices(sample_size), self.rank, None,
70
+ self.world_size)
71
+
72
+ def __len__(self) -> int:
73
+ return len(self.dataset)
74
+
75
+ def set_epoch(self, epoch: int) -> None:
76
+ """Not supported in `epoch-based runner."""
77
+ pass
78
+
79
+ def __iter__(self) -> Iterator[int]:
80
+ while True:
81
+ for source, repeat in enumerate(self.repeat):
82
+ for _ in range(repeat):
83
+ batch_buffer_per_source = []
84
+ while len(batch_buffer_per_source) < self.batch_size:
85
+ idx = next(self.source2inds[source])
86
+ idx += self.cumulative_sizes[source]
87
+ batch_buffer_per_source.append(idx)
88
+
89
+ yield from batch_buffer_per_source
90
+
91
+
92
+ class MultiSourceSampler(Sampler):
93
+ def __init__(self,
94
+ repeats,
95
+ dataset: Sized,
96
+ batch_sizes: list[int],
97
+ shuffle: bool = True,
98
+ seed: Optional[int] = None) -> None:
99
+
100
+ assert hasattr(dataset, 'cumulative_sizes'),\
101
+ f'The dataset must be a ConcatDataset, but got {dataset}'
102
+
103
+ assert isinstance(batch_sizes, list), \
104
+ f'batch_sizes must be a list, but got batch_sizes={batch_sizes}'
105
+ assert len(batch_sizes) == len(dataset.cumulative_sizes), \
106
+ 'The length of batch_sizes must be equal to ' \
107
+ f'the number of datasets, but got batch_sizes={batch_sizes}'
108
+
109
+ rank, world_size = get_dist_info()
110
+ self.rank = rank
111
+ self.world_size = world_size
112
+
113
+ self.dataset = dataset
114
+ self.cumulative_sizes = [0] + dataset.cumulative_sizes
115
+ self.batch_sizes = batch_sizes
116
+
117
+
118
+ self.seed = sync_random_seed() if seed is None else seed
119
+ self.shuffle = shuffle
120
+ self.source2inds = {
121
+ source: self._indices_of_rank(len(ds))
122
+ for source, ds in enumerate(dataset.datasets)
123
+ }
124
+
125
+ self.repeats = repeats
126
+ assert len(self.repeats) == len(self.batch_sizes)
127
+
128
+ def _infinite_indices(self, sample_size: int) -> Iterator[int]:
129
+ """Infinitely yield a sequence of indices."""
130
+ g = torch.Generator()
131
+ g.manual_seed(self.seed)
132
+ while True:
133
+ if self.shuffle:
134
+ yield from torch.randperm(sample_size, generator=g).tolist()
135
+ else:
136
+ yield from torch.arange(sample_size).tolist()
137
+
138
+ def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
139
+ """Slice the infinite indices by rank."""
140
+ yield from itertools.islice(
141
+ self._infinite_indices(sample_size), self.rank, None,
142
+ self.world_size)
143
+
144
+
145
+ def __len__(self) -> int:
146
+ return len(self.dataset)
147
+
148
+ def set_epoch(self, epoch: int) -> None:
149
+ """Not supported in `epoch-based runner."""
150
+ pass
151
+
152
+ def __iter__(self) -> Iterator[int]:
153
+ while True:
154
+ for source, (batch_size, repeat) in enumerate(zip(self.batch_sizes, self.repeats)):
155
+ for _ in range(repeat):
156
+ batch_buffer_per_source = []
157
+ while len(batch_buffer_per_source) < batch_size:
158
+ idx = next(self.source2inds[source])
159
+ idx += self.cumulative_sizes[source]
160
+ batch_buffer_per_source.append(idx)
161
+
162
+ yield from batch_buffer_per_source
163
+
164
+ @property
165
+ def batch_size(self):
166
+ batch_size_sum = sum([batch_size * repeat for batch_size, repeat in zip(self.batch_sizes, self.repeats)])
167
+ batch_size_ave = batch_size_sum // sum(self.repeats)
168
+
169
+ return batch_size_ave
170
+
171
+
172
+ class MultiSourceBatchSampler(Sampler[list[int]]):
173
+ def __init__(
174
+ self,
175
+ sampler: Union[FixedBatchMultiSourceSampler, MultiSourceSampler],
176
+ batch_sizes: list[int],
177
+ repeats: list[int],
178
+ **kwargs
179
+ ) -> None:
180
+ self.sampler = sampler
181
+ self.batch_sizes = batch_sizes
182
+ self.repeats = repeats
183
+
184
+ def __iter__(self) -> Iterator[list[int]]:
185
+ # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
186
+ sampler_iter = iter(self.sampler)
187
+
188
+ while True:
189
+ for source, (batch_size, repeat) in enumerate(zip(self.batch_sizes, self.repeats)):
190
+ for _ in range(repeat):
191
+ batch = [*itertools.islice(sampler_iter, batch_size)]
192
+ yield batch
193
+
194
+ @property
195
+ def batch_size(self):
196
+ batch_size_sum = sum([batch_size * repeat for batch_size, repeat in zip(self.batch_sizes, self.repeats)])
197
+ batch_size_ave = batch_size_sum // sum(self.repeats)
198
+
199
+ return batch_size_ave
200
+
201
+ def __len__(self) -> int:
202
+ return len(self.sampler) // self.batch_size
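To see the interleaving pattern concretely, the sketch below drives `FixedBatchMultiSourceSampler` over a two-source `ConcatDataset` and prints the first few index batches; the sampler is infinite, so only a slice is taken. The import path is an assumption based on this repo layout.

# Hedged sketch: inspect the index stream produced by FixedBatchMultiSourceSampler.
import itertools
from torch.utils.data import ConcatDataset, Dataset
from src.datasets.samplers.multi_source_sampler import FixedBatchMultiSourceSampler

class Toy(Dataset):
    def __init__(self, n):
        self.n = n
    def __len__(self):
        return self.n
    def __getitem__(self, i):
        return i

# Source 0 holds 10 samples, source 1 holds 100; repeat=(1, 2) yields one batch
# from source 0 followed by two batches from source 1, forever.
concat = ConcatDataset([Toy(10), Toy(100)])
sampler = FixedBatchMultiSourceSampler(repeat=(1, 2), dataset=concat,
                                       batch_size=4, shuffle=False, seed=0)
first = list(itertools.islice(iter(sampler), 12))
print(first[:4])   # [0, 1, 2, 3]        -> source 0
print(first[4:])   # [10, 11, ..., 17]   -> source 1, offset by cumulative size 10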
src/datasets/text2image/__init__.py ADDED
File without changes
src/datasets/text2image/text2image.py ADDED
@@ -0,0 +1,649 @@
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ import os
4
+ import json
5
+ import random
6
+ import torch
7
+ import numpy as np
8
+ from einops import rearrange
9
+ from xtuner.registry import BUILDER
10
+ from mmengine.registry import DATASETS
11
+ from src.datasets.utils import crop2square
12
+ from glob import glob
13
+ from typing import List, Dict, Any, Optional
14
+ import mmap
15
+ import struct
16
+ from src.datasets.utils import crop2square, encode_fn
17
+ from xtuner.utils import DEFAULT_IMAGE_TOKEN
18
+
19
+
20
+ @BUILDER.register_module()
21
+ class Text2ImageDataset(Dataset):
22
+ def __init__(self,
23
+ data_path,
24
+ local_folder,
25
+ image_size,
26
+ unconditional=0.1,
27
+ tokenizer=None,
28
+ prompt_template=None,
29
+ max_length=1024,
30
+ crop_image=True,
31
+ cap_source='caption',
32
+ ):
33
+ super().__init__()
34
+ self.data_path = data_path
35
+ self._load_data(data_path)
36
+ self.unconditional = unconditional
37
+ self.local_folder = local_folder
38
+ self.cap_source = cap_source
39
+ self.image_size = image_size
40
+ self.tokenizer = BUILDER.build(tokenizer)
41
+
42
+ self.prompt_template = prompt_template
43
+ self.max_length = max_length
44
+ self.crop_image = crop_image
45
+ self.metainfo = {'task': 'unified'}
46
+ self.tokenizer.add_tokens(["<image>"], special_tokens=True)
47
+
48
+
49
+
50
+ def _load_data(self, data_path):
51
+ with open(data_path, 'r') as f:
52
+ self.data_list = json.load(f)
53
+
54
+ print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)
55
+
56
+ def full_init(self):
57
+ """Dummy full_init to be compatible with MMEngine ConcatDataset."""
58
+ return
59
+
60
+
61
+ def __len__(self):
62
+ return len(self.data_list)
63
+
64
+ def _read_image(self, image_file):
65
+ image = Image.open(os.path.join(self.local_folder, image_file))
66
+ assert image.width > 8 and image.height > 8, f"Image: {image.size}"
67
+ assert image.width / image.height > 0.1, f"Image: {image.size}"
68
+ assert image.width / image.height < 10, f"Image: {image.size}"
69
+ return image
70
+
71
+ def _process_text(self, text):
72
+ if random.uniform(0, 1) < self.unconditional:
73
+ prompt = "Generate an image."
74
+ else:
75
+ prompt = f"Generate an image: {text.strip()}"
76
+ prompt = self.prompt_template['INSTRUCTION'].format(input=prompt)
77
+ input_ids = self.tokenizer.encode(prompt, add_special_tokens=True, return_tensors='pt')[0]
78
+
79
+ return dict(input_ids=input_ids[:self.max_length])
80
+
81
+ def _process_image(self, image):
82
+ data = dict()
83
+
84
+ if self.crop_image:
85
+ image = crop2square(image)
86
+ else:
87
+ target_size = max(image.size)
88
+ image = image.resize(size=(target_size, target_size))
89
+
90
+ image = image.resize(size=(self.image_size, self.image_size))
91
+ pixel_values = torch.from_numpy(np.array(image)).float()
92
+ pixel_values = pixel_values / 255
93
+ pixel_values = 2 * pixel_values - 1
94
+ pixel_values = rearrange(pixel_values, 'h w c -> c h w')
95
+
96
+ data.update(pixel_values=pixel_values)
97
+
98
+ return data
99
+
100
+ def _retry(self):
101
+ return self.__getitem__(random.choice(range(self.__len__())))
102
+
103
+ def __getitem__(self, idx):
104
+ try:
105
+ data_sample = self.data_list[idx]
106
+ image = self._read_image(data_sample['image']).convert('RGB')
107
+
108
+ caption = data_sample[self.cap_source]
109
+ data = self._process_image(image)
110
+ data.update(self._process_text(caption))
111
+ data.update(type='text2image')
112
+
113
+ return data
114
+
115
+ except Exception as e:
116
+ print(f"Error when reading {self.data_path}:{self.data_list[idx]}: {e}", flush=True)
117
+ return self._retry()
118
+
119
+ @DATASETS.register_module()
120
+ @BUILDER.register_module()
121
+ class LargeText2ImageDataset(Text2ImageDataset):
122
+ # self.data_list only contains paths of images and captions
123
+
124
+ def __init__(self, cap_folder=None, *args, **kwargs):
125
+ super().__init__(*args, **kwargs)
126
+ self.cap_folder = self.local_folder if cap_folder is None else cap_folder
127
+
128
+ def _load_data(self, data_path): # image path and annotation path are saved in a json file
129
+ if data_path.endswith(".json"):
130
+ with open(data_path, 'r') as f:
131
+ self.data_list = json.load(f)
132
+ else:
133
+ self.data_list = []
134
+ json_files = glob(f'{data_path}/*.json')
135
+ for json_file in json_files:
136
+ with open(json_file, 'r') as f:
137
+ self.data_list += json.load(f)
138
+
139
+ print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)
140
+
141
+ def __getitem__(self, idx):
142
+ try:
143
+ data_sample = self.data_list[idx]
144
+ image = self._read_image(data_sample['image']).convert('RGB')
145
+ with open(f"{self.cap_folder}/{data_sample['annotation']}", 'r') as f:
146
+ caption = json.load(f)[self.cap_source]
147
+ data = self._process_image(image)
148
+ data.update(self._process_text(caption))
149
+ data.update(type='text2image')
150
+ return data
151
+
152
+ except Exception as e:
153
+ print(f"Error when reading {self.data_path}:{data_sample}: {e}", flush=True)
154
+ return self._retry()
155
+
156
+
157
+ @DATASETS.register_module()
158
+ @BUILDER.register_module()
159
+ class MMapT2IDataset(Dataset):
160
+ """
161
+ Map-style Text2Image Dataset with mmap-based random access.
162
+ The mmap is opened once in __init__; __getitem__ reads any requested line in O(1).
163
+ """
164
+ def __init__(
165
+ self,
166
+ jsonl_path: str,
167
+ idx_path: str,
168
+ image_size: int,
169
+ tokenizer: Optional[Dict] = None,
170
+ template_map_fn: Optional[Dict] = None,
171
+ cap_source: str = "prompt",
172
+ max_length: int = 2048,
173
+ image_length: int = 512,
174
+ unconditional: float = 0.01,
175
+ crop_image: bool = False,
176
+ ):
177
+ super().__init__()
178
+
179
+ # ---------- basic parameters ----------
180
+ self.jsonl_path = jsonl_path
181
+ self.image_size = image_size
182
+ self.cap_source = cap_source
183
+ self.max_length = max_length
184
+ self.unconditional = unconditional
185
+ self.crop_image = crop_image
186
+
187
+ # ---------- tokenizer / template ----------
188
+ self.tokenizer = BUILDER.build(tokenizer)
189
+ self.template_map_fn = template_map_fn
190
+
191
+ # ---------- mmap loading ----------
192
+ self._open_mmap(jsonl_path, idx_path)
193
+ self.metainfo = {'task' :'unified'}
194
+ # ===== mmap & index =====
195
+ def _open_mmap(self, jsonl_path: str, idx_path: str):
196
+ # memory-map the jsonl file
197
+ self._jsonl_fp = open(jsonl_path, "r+b")
198
+ self._mm = mmap.mmap(self._jsonl_fp.fileno(), 0, access=mmap.ACCESS_READ)
199
+
200
+ # read the offset index
201
+ with open(idx_path, "rb") as f:
202
+ nlines = struct.unpack("<Q", f.read(8))[0]
203
+ self._offsets = np.frombuffer(f.read(8 * nlines), dtype=np.uint64)
204
+ print(f"[MMapT2IDataset] {jsonl_path}: {nlines} lines indexed")
205
+
206
+ def __len__(self) -> int:
207
+ return self._offsets.size
208
+
209
+ def full_init(self):
210
+ """Dummy full_init to be compatible with MMEngine ConcatDataset."""
211
+ return
212
+ def _read_line(self, idx: int) -> str:
213
+ off = int(self._offsets[idx])
214
+ self._mm.seek(off)
215
+ return self._mm.readline().decode("utf-8")
216
+
217
+ # ===== core processing =====
218
+ def _load_image(self, path: str) -> torch.Tensor:
219
+ img = Image.open(path).convert("RGB")
220
+
221
+ # preprocess: crop to square / pad
222
+ if self.crop_image:
223
+ img = crop2square(img)
224
+ else:
225
+ target_size = max(img.size)
226
+ img = img.resize((target_size, target_size))
227
+
228
+ img = img.resize((self.image_size, self.image_size))
229
+ arr = np.asarray(img, dtype=np.uint8) # HWC uint8
230
+ px = torch.as_tensor(arr).float() / 255.0 # 0-1
231
+ px = 2 * px - 1 # -1 ~ 1
232
+ return rearrange(px, "h w c -> c h w") # CHW
233
+
234
+ def _build_prompt(self, caption: str) -> torch.Tensor:
235
+ if random.random() < self.unconditional:
236
+ caption = "Generate an image."
237
+ else:
238
+ caption = f"Generate an image: {caption.strip()}"
239
+
240
+ instr = self.template_map_fn["INSTRUCTION"].format(input=caption)
241
+ ids = self.tokenizer.encode(
242
+ instr, add_special_tokens=True, return_tensors="pt"
243
+ )[0][: self.max_length]
244
+ return ids
245
+
246
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
247
+ # 1) fetch the jsonl line
248
+ sample = json.loads(self._read_line(idx))
249
+
250
+ # 2) load & process the image
251
+ pixel_values = self._load_image(sample["image"])
252
+
253
+ # 3) process the text
254
+ caption = sample.get(self.cap_source, "")
255
+ input_ids = self._build_prompt(caption)
256
+
257
+ # 4) pack the sample
258
+ data = dict(
259
+ pixel_values=pixel_values,
260
+ input_ids=input_ids,
261
+ type="text2image",
262
+ image_file=sample["image"],
263
+ idx=idx,
264
+ )
265
+ return data
266
+
267
+
268
+ @DATASETS.register_module()
269
+ @BUILDER.register_module()
270
+ class ReconstructDataset(Dataset):
271
+ def __init__(self,
272
+ data_path: str,
273
+ image_size: int,
274
+ tokenizer=None,
275
+ prompt_template=None,
276
+ cap_source: str = "prompt",
277
+ max_length: int = 8192,
278
+ crop_image: bool = True,
279
+ img_prefix: str = ""):
280
+ super().__init__()
281
+ self.image_size = image_size
282
+ self.tokenizer = BUILDER.build(tokenizer)
283
+ self.tokenizer.add_tokens(["<image>"], special_tokens=True)
284
+ self.prompt_template = prompt_template
285
+ self.cap_source = cap_source
286
+ self.max_length = max_length
287
+ self.crop_image = crop_image
288
+ self.img_prefix = img_prefix
289
+ self._load_data(data_path)
290
+
291
+ m = n = self.image_size // 16
292
+ self.image_token_repeat = m * n + 64
293
+ self.metainfo = {'task': 'unified'}
294
+
295
+ def full_init(self):
296
+ """Dummy full_init to be compatible with MMEngine ConcatDataset."""
297
+ return
298
+
299
+ def _load_data(self, path):
300
+ with open(path) as f:
301
+ self.data_list = [json.loads(l) for l in f]
302
+ print(f"[I2ICaptionReconstructDataset] Loaded {len(self.data_list)} samples from {path}")
303
+
304
+ def _add_prefix(self, rel):
305
+ return os.path.join(self.img_prefix, rel.lstrip("/")) if self.img_prefix else rel
306
+
307
+ def _read_image(self, path):
308
+ img = Image.open(path).convert("RGB")
309
+ assert img.width > 8 and img.height > 8 and 0.1 < img.width / img.height < 10
310
+ return img
311
+
312
+ # ---------- preprocess ----------
313
+ def _process_image(self, img):
314
+ img = crop2square(img) if self.crop_image else img.resize((max(img.size),)*2)
315
+ img = img.resize((self.image_size, self.image_size))
316
+ px = torch.from_numpy(np.array(img)).float() / 255.
317
+ px = 2 * px - 1
318
+ return rearrange(px, "h w c -> c h w")
319
+
320
+ def _encode_prompt(self, text):
321
+ # for bad_token in ["[IMAGE]", "<image_placeholder>", "<image_plaeholder>"]:
322
+ # text = text.replace(bad_token, "")
323
+ text = "Repeat this image."
324
+ prompt_in = f"<image>\n{text.strip()}"
325
+ prompt = self.prompt_template["INSTRUCTION"].format(input=prompt_in)
326
+ prompt = prompt.replace("<image>", "<image>" * self.image_token_repeat)
327
+ input_ids = self.tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")[0]
328
+ mask = (input_ids != self.tokenizer.pad_token_id).long()
329
+ return input_ids[:self.max_length], mask[:self.max_length]
330
+
331
+ def __len__(self):
332
+ return len(self.data_list)
333
+
334
+ def _retry(self):
335
+ return self.__getitem__(random.randrange(len(self)))
336
+
337
+ def __getitem__(self, idx):
338
+ try:
339
+ sample = self.data_list[idx]
340
+ src_img = self._read_image(self._add_prefix(sample["image"]))
341
+ tgt_img = src_img
342
+ caption = sample[self.cap_source]
343
+
344
+ px_src = self._process_image(src_img)
345
+ px_tgt = self._process_image(tgt_img)
346
+ input_ids, mask = self._encode_prompt(caption)
347
+
348
+ return {
349
+ "pixel_values_src": px_src,
350
+ "pixel_values": px_tgt,
351
+ "input_ids": input_ids,
352
+ "attention_mask": mask,
353
+ "type": "image_edit"
354
+ }
355
+ except Exception as e:
356
+ print(f"[I2ICaptionReconstructDataset] Error @ {idx}: {e}")
357
+ return self._retry()
358
+
359
+ @DATASETS.register_module()
360
+ @BUILDER.register_module()
361
+ class UncondReconstructDataset(Dataset):
362
+ def __init__(self,
363
+ data_path: str,
364
+ image_size: int,
365
+ tokenizer=None,
366
+ prompt_template=None,
367
+ cap_source: str = "prompt",
368
+ max_length: int = 8192,
369
+ crop_image: bool = True,
370
+ img_prefix: str = ""):
371
+ super().__init__()
372
+ self.image_size = image_size
373
+ self.tokenizer = BUILDER.build(tokenizer)
374
+ self.tokenizer.add_tokens(["<image>"], special_tokens=True)
375
+ self.prompt_template = prompt_template
376
+ self.max_length = max_length
377
+ self.crop_image = crop_image
378
+ self.img_prefix = img_prefix
379
+ self.cap_source = cap_source
380
+
381
+
382
+ self._load_data(data_path)
383
+
384
+ # number of expanded image tokens
385
+ m = n = self.image_size // 16
386
+ self.image_token_repeat = m * n + 64
387
+ self.metainfo = {'task': 'unified'}
388
+
389
+ def _load_data(self, path):
390
+ with open(path) as f:
391
+ self.data_list = [json.loads(l) for l in f]
392
+ print(f"[I2IUncondReconstructDataset] Loaded {len(self.data_list)} samples from {path}")
393
+
394
+ def _add_prefix(self, rel_path):
395
+ return os.path.join(self.img_prefix, rel_path.lstrip("/")) if self.img_prefix else rel_path
396
+
397
+ def full_init(self):
398
+ """Dummy full_init to be compatible with MMEngine ConcatDataset."""
399
+ return
400
+ def _read_image(self, path):
401
+ image = Image.open(path).convert("RGB")
402
+ assert image.width > 8 and image.height > 8 and 0.1 < image.width / image.height < 10
403
+ return image
404
+
405
+
406
+ # ---------- preprocess ----------
407
+ def _process_image(self, img):
408
+ img = crop2square(img) if self.crop_image else img.resize((max(img.size),)*2)
409
+ img = img.resize((self.image_size, self.image_size))
410
+ px = torch.from_numpy(np.array(img)).float() / 255.
411
+ px = 2 * px - 1
412
+ return rearrange(px, "h w c -> c h w")
413
+
414
+ def __len__(self):
415
+ return len(self.data_list)
416
+
417
+ def _retry(self, max_tries=5):
418
+ for _ in range(max_tries):
419
+ try:
420
+ return self.__getitem__(random.randrange(len(self)))
421
+ except Exception:
422
+ continue
423
+ raise RuntimeError("Exceeded max retries in I2IUncondReconstructDataset")
424
+
425
+ def __getitem__(self, idx):
426
+ try:
427
+ sample = self.data_list[idx]
428
+ path = self._add_prefix(sample["image"])
429
+ img = self._read_image(path)
430
+ px = self._process_image(img)
431
+
432
+ # ==== fill in empty text ====
433
+ input_ids = torch.zeros(0, dtype=torch.long)
434
+ attention_mask = torch.zeros(0, dtype=torch.long)
435
+
436
+ return {
437
+ "pixel_values_src": px,
438
+ "pixel_values": px.clone(),
439
+ "type": "image_edit",
440
+ "input_ids": input_ids,
441
+ "attention_mask": attention_mask,
442
+ # the reconstruction task keeps input_ids / attention_mask empty
443
+ }
444
+ except Exception as e:
445
+ print(f"[I2IUncondReconstructDataset] Error @ {idx}: {e}")
446
+ return self._retry()
447
+
448
+
449
+
450
+ @DATASETS.register_module()
451
+ @BUILDER.register_module()
452
+ class Text2ImageJSONLDataset(Dataset):
453
+ def __init__(self,
454
+ data_path,
455
+ image_size,
456
+ tokenizer=None,
457
+ prompt_template=None,
458
+ cap_source='prompt',
459
+ max_length=1024,
460
+ unconditional=0.1,
461
+ crop_image=True,
462
+ ):
463
+ super().__init__()
464
+ self.data_path = data_path
465
+ self._load_data(data_path)
466
+ self.image_size = image_size
467
+ self.tokenizer = BUILDER.build(tokenizer)
468
+ self.tokenizer.add_tokens(["<image>"], special_tokens=True)
469
+ self.prompt_template = prompt_template
470
+ self.cap_source = cap_source
471
+ self.max_length = max_length
472
+ self.unconditional = unconditional
473
+ self.crop_image = crop_image
474
+ self.metainfo = {'task': 'unified'}
475
+
476
+ def _load_data(self, data_path):
477
+ self.data_list = []
478
+ with open(data_path, 'r') as f:
479
+ for line in f:
480
+ self.data_list.append(json.loads(line.strip()))
481
+ print(f"Loaded {len(self.data_list)} samples from {data_path}")
482
+
483
+ def full_init(self):
484
+ """Dummy full_init for MMEngine ConcatDataset compatibility."""
485
+ pass
486
+ def __len__(self):
487
+ return len(self.data_list)
488
+
489
+ def _read_image(self, image_file):
490
+ image = Image.open(image_file).convert('RGB')
491
+ assert image.width > 8 and image.height > 8
492
+ assert 0.1 < image.width / image.height < 10
493
+ return image
494
+
495
+ def _process_image(self, image):
496
+ if self.crop_image:
497
+ image = crop2square(image)
498
+ else:
499
+ target_size = max(image.size)
500
+ image = image.resize((target_size, target_size))
501
+
502
+ image = image.resize((self.image_size, self.image_size))
503
+ pixel_values = torch.from_numpy(np.array(image)).float() / 255.0
504
+ pixel_values = 2 * pixel_values - 1 # [-1, 1]
505
+ pixel_values = rearrange(pixel_values, 'h w c -> c h w')
506
+ return dict(pixel_values=pixel_values)
507
+
508
+ def _process_text(self, text):
509
+ if random.uniform(0, 1) < self.unconditional:
510
+ text = "Generate an image."
511
+ else:
512
+ text = f"Generate an image: {text.strip()}"
513
+ prompt = self.prompt_template['INSTRUCTION'].format(input=text)
514
+ input_ids = self.tokenizer.encode(prompt, add_special_tokens=True, return_tensors='pt')[0]
515
+ return dict(input_ids=input_ids[:self.max_length])
516
+
517
+ def _retry(self):
518
+ return self.__getitem__(random.randint(0, len(self.data_list) - 1))
519
+
520
+ def __getitem__(self, idx):
521
+ try:
522
+ sample = self.data_list[idx]
523
+ image = self._read_image(sample['image'])
524
+ caption = sample[self.cap_source]
525
+ data = self._process_image(image)
526
+ data.update(self._process_text(caption))
527
+ data.update(type='text2image')
528
+ return data
529
+ except Exception as e:
530
+ print(f"[JSONLDataset] Error reading sample #{idx}: {e}")
531
+ return self._retry()
532
+
533
+
534
+
535
+ # Plain text-to-image has no placeholder issue; the editing datasets below must handle the <image> placeholder.
536
+ @DATASETS.register_module()
537
+ @BUILDER.register_module()
538
+ class ImageEditJSONLDataset(Dataset):
539
+ """
540
+ Dataset for <src, tgt, prompt> image editing, now decoupled from tokenization logic.
541
+ """
542
+ def __init__(self,
543
+ data_path: str,
544
+ image_size: int,
545
+ tokenizer=None,
546
+ prompt_template=None,
547
+ max_length: int = 8192,
548
+ cap_source: str = "prompt",
549
+ unconditional: float = 0,
550
+ crop_image: bool = False,
551
+ img_prefix: str = ""):
552
+ super().__init__()
553
+ self.data_path = data_path
554
+ self.image_size = image_size
555
+ self.tokenizer = BUILDER.build(tokenizer)
556
+ self.prompt_template = prompt_template
557
+ self.max_length = max_length
558
+ self.cap_source = cap_source
559
+ self.unconditional = unconditional
560
+ self.crop_image = crop_image
561
+ self.img_prefix = img_prefix
562
+ self._load_data(data_path)
563
+ # Calculate image token repetition length, consistent with inference.
564
+ m = n = self.image_size // 16
565
+ self.image_token_repeat = m * n + 64
566
+ self.metainfo = {'task': 'unified'}
567
+
568
+ self.tokenizer.add_tokens(["<image>"], special_tokens=True)
569
+ self.image_token_idx = self.tokenizer.convert_tokens_to_ids("<image>")
570
+ print(f"Registered <image> token at index {self.image_token_idx}")
571
+
572
+ def _load_data(self, path):
573
+ with open(path) as f:
574
+ self.data_list = [json.loads(l) for l in f]
575
+ print(f"[ImageEditJSONLDataset] Loaded {len(self.data_list)} samples from {path}")
576
+
577
+ def full_init(self):
578
+ """Dummy full_init for MMEngine ConcatDataset compatibility."""
579
+ pass
580
+
581
+ def _add_prefix(self, rel_path):
582
+ return os.path.join(self.img_prefix, rel_path.lstrip("/")) if self.img_prefix else rel_path
583
+
584
+ def _read_image(self, path):
585
+ path = path.replace("datasets_vlm02", "datasets_vlm")
586
+ img = Image.open(path).convert("RGB")
587
+ assert img.width > 8 and img.height > 8 and 0.1 < img.width / img.height < 10
588
+ return img
589
+
590
+ def _process_image(self, img):
591
+ img = crop2square(img) if self.crop_image else img.resize((max(img.size),) * 2)
592
+ img = img.resize((self.image_size, self.image_size))
593
+ px = torch.from_numpy(np.array(img)).float() / 255.
594
+ px = 2 * px - 1
595
+ return rearrange(px, "h w c -> c h w")
596
+
597
+ # --- REFACTORED: This method now only prepares the raw prompt text ---
598
+ def _prepare_prompt_text(self, raw_text: str):
599
+ """Cleans text and handles unconditional generation."""
600
+
601
+ txt = raw_text
+ for bad_token in ["[IMAGE]", "<image_placeholder>", "<image_plaeholder>", "<image>"]:
+ txt = txt.replace(bad_token, "")
603
+ txt = txt.strip()
604
+
605
+ if random.random() < self.unconditional:
606
+ txt = "Edit this image."
607
+ return txt
608
+
609
+ def _retry(self):
610
+ return self.__getitem__(random.randrange(len(self)))
611
+
612
+ def __len__(self):
613
+ return len(self.data_list)
614
+
615
+ def __getitem__(self, idx):
616
+ try:
617
+ sample = self.data_list[idx]
618
+ src_path, tgt_path = map(self._add_prefix, [sample["images"][0], sample["image"]])
619
+ src_img, tgt_img = map(self._read_image, [src_path, tgt_path])
620
+
621
+ px_src, px_tgt = map(self._process_image, [src_img, tgt_img])
622
+
623
+ # --- MODIFIED: Call the unified encode_fn ---
624
+ # 1. Prepare the raw prompt string
625
+ prompt_text = self._prepare_prompt_text(sample[self.cap_source])
626
+
627
+ # 2. Delegate all encoding and formatting to encode_fn
628
+ encoded_text = encode_fn(
629
+ example=prompt_text,
630
+ tokenizer=self.tokenizer,
631
+ prompt_template=self.prompt_template,
632
+ max_length=self.max_length,
633
+ image_length=self.image_token_repeat,
634
+ image_token_idx=self.image_token_idx
635
+ )
636
+
637
+ return {
638
+ "pixel_values_src": px_src,
639
+ "pixel_values": px_tgt,
640
+ "input_ids": torch.tensor(encoded_text["input_ids"], dtype=torch.long),
641
+ "attention_mask": torch.tensor(encoded_text["attention_mask"], dtype=torch.long),
642
+ "type": "image_edit",
643
+ }
644
+ except Exception as e:
645
+ print(f"[ImageEditJSONLDataset] Error @ {idx}: {e} from {self.data_path}")
646
+ return self._retry()
647
+
648
+
649
+
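`MMapT2IDataset` above expects a sidecar index file: an unsigned 64-bit little-endian line count followed by one uint64 byte offset per jsonl line, which is what `_open_mmap` reads back with `struct.unpack('<Q', ...)` and `np.frombuffer`. A hedged one-off builder script (file names are placeholders) could look like this:

# Hedged sketch: build the offset index that MMapT2IDataset._open_mmap reads.
import struct
import sys

def build_index(jsonl_path: str, idx_path: str) -> None:
    offsets = []
    with open(jsonl_path, "rb") as f:
        while True:
            pos = f.tell()          # byte offset of the line about to be read
            line = f.readline()
            if not line:
                break
            offsets.append(pos)
    with open(idx_path, "wb") as f:
        f.write(struct.pack("<Q", len(offsets)))             # line count, uint64 LE
        f.write(struct.pack(f"<{len(offsets)}Q", *offsets))  # one uint64 offset per line

if __name__ == "__main__":
    build_index(sys.argv[1], sys.argv[2])  # e.g. data.jsonl data.idx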
src/datasets/understanding/caption_datasets.py ADDED
@@ -0,0 +1,342 @@
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ import os
4
+ import io
5
+ import json
6
+ import random
7
+ import torch
8
+ import numpy as np
9
+ from einops import rearrange
10
+ try:
11
+ from aoss_client.client import Client
12
+ except:
13
+ try:
14
+ from petrel_client.client import Client
15
+ except:
16
+ Client = None
17
+ from glob import glob
18
+ from xtuner.registry import BUILDER
19
+ from xtuner.dataset.utils import expand2square
20
+ from src.datasets.utils import crop2square, encode_fn
21
+ from xtuner.utils import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
22
+ from src.datasets.understanding.caption_prompts import dense_prompts, short_prompts
23
+ from typing import List, Dict, Any, Optional,Callable,Tuple
24
+
25
+
26
+ @BUILDER.register_module()
27
+ class CaptionDataset(Dataset):
28
+ def __init__(self,
29
+ data_path,
30
+ local_folder,
31
+ image_size,
32
+ ceph_folder=None,
33
+ ceph_config=None,
34
+ tokenizer=None,
35
+ template_map_fn=None,
36
+ max_length=2048,
37
+ min_image_size=80,
38
+ image_length=256,
39
+ pad_image=True,
40
+ brief=False,
41
+ cap_folder=None,
42
+ cap_source='caption',
43
+ ):
44
+ super().__init__()
45
+ self.data_path = data_path
46
+ self._load_data(data_path)
47
+ self.local_folder = local_folder
48
+ self.cap_folder = local_folder if cap_folder is None else cap_folder
49
+ self.cap_source = cap_source
50
+
51
+ self.image_size = image_size
52
+
53
+ self.tokenizer = BUILDER.build(tokenizer)
54
+ self.prompt_template = template_map_fn['template']
55
+ self.template_map_fn = BUILDER.build(template_map_fn)
56
+ self.max_length = max_length
57
+ self.image_length = image_length
58
+ self.pad_image = pad_image
59
+ self.min_image_size = min_image_size
60
+
61
+ self.FILE_CLIENT = None
62
+ self.ceph_folder = ceph_folder
63
+ self.ceph_config = ceph_config
64
+ self.use_ceph = ((Client is not None) and (ceph_folder is not None)
65
+ and (ceph_config is not None) and os.path.exists(ceph_config))
66
+
67
+ self.brief = brief
68
+ self.caption_prompts = short_prompts if self.brief else dense_prompts
69
+
70
+ def _load_data(self, data_path: str): # image path and annotation path are saved in a json file
71
+ if data_path.endswith('.json'):
72
+ with open(data_path, 'r') as f:
73
+ self.data_list = json.load(f)
74
+ else:
75
+ json_files = glob(f"{data_path}/*.json")
76
+ data_list = []
77
+ for json_file in json_files:
78
+ with open(json_file, 'r') as f:
79
+ data_list += json.load(f)
80
+
81
+ self.data_list = data_list
82
+
83
+ print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)
84
+
85
+ def __len__(self):
86
+ return len(self.data_list)
87
+
88
+ def _read_ceph(self, ceph_path):
89
+ if self.FILE_CLIENT is None:
90
+ self.FILE_CLIENT = Client(self.ceph_config)
91
+ data_bytes = self.FILE_CLIENT.get(ceph_path)
92
+
93
+ return io.BytesIO(data_bytes)
94
+
95
+ def _read_image(self, image_file):
96
+ if self.use_ceph:
97
+ image = Image.open(
98
+ self._read_ceph(
99
+ os.path.join(self.ceph_folder, image_file)
100
+ )
101
+ )
102
+ else:
103
+ image = Image.open(
104
+ os.path.join(self.local_folder, image_file)
105
+ )
106
+ assert image.width > self.min_image_size and image.height > self.min_image_size, f"Image: {image.size}"
107
+ assert image.width / image.height > 0.1, f"Image: {image.size}"
108
+ assert image.width / image.height < 10, f"Image: {image.size}"
109
+ return image.convert('RGB')
110
+
111
+ def _read_json(self, annotation_file):
112
+ if self.use_ceph:
113
+ annotation = json.load(
114
+ self._read_ceph(
115
+ os.path.join(self.ceph_folder, annotation_file)
116
+ )
117
+ )
118
+ else:
119
+ with open(os.path.join(self.local_folder, annotation_file), 'r') as f:
120
+ annotation = json.load(f)
121
+
122
+ return annotation
123
+
124
+ def _process_image(self, image):
125
+ data = dict()
126
+ if self.pad_image:
127
+ image = expand2square(image, (127, 127, 127))
128
+ else:
129
+ image = crop2square(image)
130
+
131
+ image = image.resize(size=(self.image_size, self.image_size))
132
+ pixel_values = torch.from_numpy(np.array(image)).float()
133
+ pixel_values = pixel_values / 255
134
+ pixel_values = 2 * pixel_values - 1
135
+ pixel_values = rearrange(pixel_values, 'h w c -> c h w')
136
+
137
+ data.update(pixel_values=pixel_values)
138
+ return data
139
+
140
+ def _process_text(self, text):
141
+ assert DEFAULT_IMAGE_TOKEN not in text, text
142
+ data_dict = dict(conversation=[{'input': f"{DEFAULT_IMAGE_TOKEN}\n{random.choice(self.caption_prompts)}",
143
+ 'output': text.strip()}])
144
+ data_dict.update(self.template_map_fn(data_dict))
145
+ data_dict.update(encode_fn(data_dict, self.tokenizer, self.max_length,
146
+ self.image_length, True, True))
147
+
148
+ assert (torch.tensor(data_dict['input_ids']).long() == IMAGE_TOKEN_INDEX).sum() == self.image_length, \
149
+ "Error in image format"
150
+
151
+ data_dict['type'] = 'image2text'
152
+ return data_dict
153
+
154
+ def _retry(self):
155
+ return self.__getitem__(random.choice(range(self.__len__())))
156
+
157
+ def __getitem__(self, idx):
158
+ try:
159
+ data_sample = self.data_list[idx]
160
+ image = self._read_image(data_sample['image']).convert('RGB')
161
+ data = self._process_image(image)
162
+ del image
163
+ with open(f"{self.cap_folder}/{data_sample['annotation']}", 'r') as f:
164
+ caption = json.load(f)[self.cap_source]
165
+ data.update(self._process_text(caption))
166
+
167
+ data.update(image_dir=self.local_folder, image_file=data_sample['image'])
168
+
169
+ return data
170
+
171
+ except Exception as e:
172
+ print(f"Error when reading {self.data_path}:{data_sample['image']}: {e}", flush=True)
173
+ return self._retry()
174
+
175
+
176
+ @BUILDER.register_module()
177
+ class VqaDataset(Dataset):
178
+ """Generic VQA / multimodal conversation dataset with robust IO & validation."""
179
+ # ---------- initialization ----------
180
+ def __init__(
181
+ self,
182
+ data_path: str,
183
+ tokenizer,  # required, listed first
+ template_map_fn: Callable,  # required, listed first
185
+ img_prefix: Optional[str] = None,
186
+ image_size: int = 512,
187
+ max_length: int = 2048,
188
+ image_length: int = 1089,
189
+ pad_image: bool = True,
190
+ min_image_size: int = 80,
191
+ image_token_patterns: Tuple[str, ...] = ('<image>', '[image]', '<img>'),
192
+ max_retry: int = 5,
193
+ ):
194
+ super().__init__()
195
+
196
+ self.img_prefix = img_prefix.rstrip("/") if img_prefix else None
197
+ self.image_size = image_size
198
+ self.max_length = max_length
199
+ self.image_length = image_length
200
+ self.pad_image = pad_image
201
+ self.min_image_size = min_image_size
202
+ self.image_token_patterns = list(image_token_patterns)
203
+ self.max_retry = max_retry
204
+
205
+ # build tokenizer and template
+ self.tokenizer = BUILDER.build(tokenizer)
+ self.template_map_fn = BUILDER.build(template_map_fn) if template_map_fn else None
+ # register <image> so self.image_token_idx (used in _format_conversation) is defined
+ self.tokenizer.add_tokens(["<image>"], special_tokens=True)
+ self.image_token_idx = self.tokenizer.convert_tokens_to_ids("<image>")
208
+
209
+ # read jsonl file(s) from a path or directory
210
+ self.data_list = self._load_jsonl_list(data_path)
211
+ print(f"Loaded {len(self.data_list)} samples from {data_path}")
212
+
213
+ # ---------- data loading helpers ----------
214
+ @staticmethod
215
+ def _load_jsonl_list(path: str) -> List[Dict[str, Any]]:
216
+ data: List[Dict[str, Any]] = []
217
+ if path.endswith(".jsonl"):
218
+ files = [path]
219
+ else:
220
+ files = sorted(glob(os.path.join(path, "**/*.jsonl"), recursive=True))
221
+
222
+ for file in files:
223
+ with open(file, "r") as f:
224
+ for line in f:
225
+ line = line.strip()
226
+ if line:
227
+ data.append(json.loads(line))
228
+ return data
229
+
230
+ # ---------- basic interface ----------
231
+ def __len__(self) -> int:
232
+ return len(self.data_list)
233
+
234
+ # ---------- image processing ----------
235
+ def _get_image_path(self, img_file: str) -> str:
236
+ """保持绝对路径不变,否则加前缀"""
237
+ return img_file if os.path.isabs(img_file) else os.path.join(self.img_prefix, img_file)
238
+
239
+ def _read_image(self, img_file: str) -> Image.Image:
240
+ img_path = self._get_image_path(img_file)
241
+ try:
242
+ image = Image.open(img_path).convert("RGB")
243
+ except Exception as e:
244
+ raise FileNotFoundError(f"Cannot open image: {img_path} ({e})")
245
+
246
+ w, h = image.size
247
+ if w < self.min_image_size or h < self.min_image_size:
248
+ raise ValueError(f"Image too small: {img_path} ({w}x{h})")
249
+ ratio = w / h
250
+ if not (0.1 < ratio < 10):
251
+ raise ValueError(f"Odd aspect ratio ({ratio:.3f}) for {img_path}")
252
+
253
+ # pad / crop
254
+ image = expand2square(image, (127, 127, 127)) if self.pad_image else crop2square(image)
255
+ image = image.resize((self.image_size, self.image_size), resample=Image.BICUBIC)
256
+
257
+ px = torch.from_numpy(np.asarray(image)).float() / 255.0
258
+ px = 2 * px - 1.0
259
+ px = rearrange(px, "h w c -> c h w") # CHW
260
+ return px
261
+
262
+ # ---------- conversation handling ----------
263
+ def _replace_image_tokens(self, txt: str) -> str:
264
+ for pat in self.image_token_patterns:
265
+ if pat in txt:
266
+ txt = txt.replace(pat, str(self.image_token_idx))
267
+ return txt
268
+
269
+ def _format_conversation(self, turns: List[Dict[str, str]]) -> Dict[str, Any]:
270
+ """
271
+ Merge human/gpt turns into {'input': ..., 'output': ...} pairs.
+ A human turn and the gpt turn that follows form one pair; incomplete or empty pairs are skipped.
273
+ """
274
+ pairs = []
275
+
276
+ for i in range(0, len(turns), 2):  # one pair per two turns: human then gpt
277
+ if i + 1 < len(turns):  # make sure the gpt turn exists
278
+ human_turn = turns[i]
279
+ gpt_turn = turns[i + 1]
280
+
281
+ human_content = human_turn.get("value", "").strip()
282
+ gpt_content = gpt_turn.get("value", "").strip()
283
+
284
+ if not human_content.lstrip().startswith("<image>"):
285
+ human_content = f"<image>\n{human_content}"
286
+
287
+ if not human_content or not gpt_content:  # skip the pair if either side is empty
288
+ continue
289
+
290
+ # only the human turn carries the image token
+ # human_content = self._replace_image_tokens(human_content)  # replace with image_token_idx
292
+
293
+ pairs.append({"input": human_content, "output": gpt_content})
294
+
295
+ data_dict = {"conversation": pairs}
296
+ data_dict_ori = data_dict
297
+ if self.template_map_fn:
298
+ data_dict = self.template_map_fn(data_dict)
299
+
300
+ # encode the conversation
301
+ data_dict = encode_fn(
302
+ data_dict,
303
+ self.tokenizer,
304
+ self.max_length,
305
+ self.image_length,
306
+ input_ids_with_output=True,
307
+ with_image_token=True,
308
+ # also pass image_token_idx through
309
+ image_token_idx=self.image_token_idx
310
+ )
311
+
312
+ # sanity check: make sure the image token actually appears
313
+ img_tokens = (torch.tensor(data_dict["input_ids"]) == self.image_token_idx).sum().item()
314
+
315
+ # log the lengths with an f-string for readability
+ print(f"[check] input_ids length: {len(data_dict['input_ids'])}, image token count: {img_tokens}\n")
+ # print(f"[check] input_ids: {data_dict.get('input_ids', 'not set')}\n")
318
+ if img_tokens != 1088:
319
+ print(f"[异常对话]:{data_dict_ori}")
320
+
321
+ data_dict["type"] = "image2text" # 设置数据类型为 image2text
322
+ return data_dict
323
+
324
+
325
+ # ---------- main interface ----------
326
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
327
+ for attempt in range(self.max_retry):
328
+ try:
329
+ sample = self.data_list[idx]
330
+ img_tensor = self._read_image(sample["image"])
331
+ text_data = self._format_conversation(sample.get("conversations", []))
332
+ return {
333
+ **text_data,
334
+ "pixel_values": img_tensor,
335
+ "image_file": sample["image"],
336
+ }
337
+ except Exception as e:
338
+ print(f"[Retry {attempt+1}/{self.max_retry}] idx={idx} error: {e}")
339
+ idx = random.randint(0, len(self) - 1)
340
+
341
+ # give up after repeated failures
342
+ raise RuntimeError(f"Failed to fetch valid sample after {self.max_retry} retries.")
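Judging from `__getitem__` and `_format_conversation` above, each jsonl record consumed by `VqaDataset` looks roughly like the sketch below (values and the `from` fields are illustrative; the code only pairs turns by position and reads their `value` keys):

# Hedged example of a single VqaDataset jsonl record (illustrative values only).
record = {
    "image": "images/0001.jpg",   # joined with img_prefix unless it is absolute
    "conversations": [
        {"from": "human", "value": "<image>\nWhat colour is the car?"},
        {"from": "gpt", "value": "The car is red."},
        {"from": "human", "value": "Where is it parked?"},
        {"from": "gpt", "value": "In front of a brick building."},
    ],
}
# _format_conversation pairs turns (0, 1), (2, 3), ... into
# [{"input": "<image>\n...", "output": "..."}, ...] before calling encode_fn.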
src/datasets/understanding/caption_prompts.py ADDED
@@ -0,0 +1,28 @@
1
+ dense_prompts = [
2
+ "Describe the image in detail.",
3
+ "Provide a comprehensive description of everything you see in the picture.",
4
+ "Explain the scene depicted in the image as if you were describing it to someone who cannot see it.",
5
+ "List all the objects and activities taking place in this image.",
6
+ "What is the story being told by this image? Describe in detail.",
7
+ "Imagine you are giving a detailed tour of the image's scene. How would you describe it?",
8
+ "Describe the foreground, background, and any notable features of the image.",
9
+ "How would you describe this image to build a replica of the scene?",
10
+ "Write a paragraph detailing the setting, characters, and actions visible in this image.",
11
+ "Describe every aspect of the image, including the environment, objects, and any people present.",
12
+ "Provide a detailed analysis of the composition and elements of the image.",
13
+ "What are the main focal points of this image? Describe them in detail.",
14
+ "Catalog all visible elements in the image and describe their significance to the overall scene."
15
+ ]
16
+
17
+
18
+ short_prompts = [
19
+ "Briefly describe the image",
20
+ "Summarize the key elements of the image in one sentence.",
21
+ "Give a concise description of the scene.",
22
+ "Briefly, what is happening in this image?",
23
+ "What is the most noticeable feature of this image?",
24
+ "Summarize the image in a sentence.",
25
+ "What activity is being depicted in the image?",
26
+ "Describe the setting of the image in a few words.",
27
+ "Caption this image for a social media post."
28
+ ]
src/datasets/understanding/vlm_datasets_sig.py ADDED
@@ -0,0 +1,168 @@
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ import os
4
+ import json
5
+ import random
6
+ import torch
7
+ import numpy as np
8
+ from einops import rearrange
9
+ from xtuner.registry import BUILDER
10
+ from xtuner.dataset.utils import expand2square
11
+ from src.datasets.utils import crop2square, encode_fn, load_jsonl
12
+ from xtuner.utils import DEFAULT_IMAGE_TOKEN
13
+ from transformers import AutoImageProcessor
14
+
15
+
16
+ class VLMDataset(Dataset):
17
+ def __init__(
18
+ self,
19
+ data_path,
20
+ image_size,
21
+ tokenizer=None,
22
+ template_map_fn=None,
23
+ max_length=2048,
24
+ min_image_size=80,
25
+ pad_image=True,
26
+ local_folder="",
27
+ key_value="conversations",
28
+ ):
29
+ super().__init__()
30
+ self.data_path = data_path
31
+ self._load_data(data_path)
32
+ self.image_size = image_size
33
+
34
+ self.tokenizer = BUILDER.build(tokenizer)
35
+ self.prompt_template = template_map_fn["template"]
36
+ self.template_map_fn = BUILDER.build(template_map_fn)
37
+ self.max_length = max_length
38
+ self.pad_image = pad_image
39
+ self.min_image_size = min_image_size
40
+ self.key_value = key_value
41
+ self.processor = AutoImageProcessor.from_pretrained(
42
+ "checkpoint/siglip2-so400m-patch16-512"
43
+ )
44
+ self.metainfo = {'task' :'unified'}
45
+ self.DEFAULT_IMAGE_TOKEN = DEFAULT_IMAGE_TOKEN
46
+ m = n = self.image_size // 16
47
+ self.image_token_repeat = m * n + 64
48
+
49
+ self.tokenizer.add_tokens(["<image>"], special_tokens=True)
50
+ self.image_token_idx = self.tokenizer.convert_tokens_to_ids("<image>")
51
+ print(f"Registered <image> token at index {self.image_token_idx}")
52
+
53
+ def _load_data(
54
+ self, data_path: str
55
+ ): # image path and annotation path are saved in a json file
56
+ self.data_list = load_jsonl(data_path)
57
+ print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)
58
+
59
+ def full_init(self):
60
+ """Dummy full_init to be compatible with MMEngine ConcatDataset."""
61
+ return
62
+ def __len__(self):
63
+ return len(self.data_list)
64
+
65
+ def _read_image(self, image_file):
66
+ image = Image.open(image_file)
67
+ assert (
68
+ image.width > self.min_image_size and image.height > self.min_image_size
69
+ ), f"Image: {image.size}"
70
+ assert image.width / image.height > 0.1, f"Image: {image.size}"
71
+ assert image.width / image.height < 10, f"Image: {image.size}"
72
+ return image.convert("RGB")
73
+
74
+ # def _process_image(self, image):
75
+ # data = dict()
76
+ # # if self.pad_image:
77
+ # # image = expand2square(image, (127, 127, 127))
78
+ # # else:
79
+ # # image = crop2square(image)
80
+
81
+ # # image = image.resize(size=(self.image_size, self.image_size))
82
+ # # pixel_values = torch.from_numpy(np.array(image)).float()
83
+ # # pixel_values = pixel_values / 255
84
+ # # pixel_values = 2 * pixel_values - 1
85
+ # # pixel_values = rearrange(pixel_values, "h w c -> c h w")
86
+ # image = image.resize((self.image_size, self.image_size))
87
+ # inputs = self.processor(images=image, return_tensors="pt")
88
+ # pixel_values = inputs["pixel_values"].squeeze(0)
89
+
90
+ # data.update(pixel_values=pixel_values)
91
+ # return data
92
+
93
+
94
+ def _process_image(self, image: Image.Image):
95
+ # 1) optionally crop/pad to square
96
+ if self.pad_image:
97
+ image = crop2square(image)
98
+ # 2) resize to the target size
99
+ image = image.resize((self.image_size, self.image_size))
100
+ # 3) to tensor & normalize
101
+ arr = np.array(image).astype(np.float32) / 255.0 # HWC
102
+ arr = 2 * arr - 1 # [-1,1]
103
+ tensor = torch.from_numpy(arr) # HWC
104
+ tensor = rearrange(tensor, "h w c -> c h w") # CHW
105
+ return {"pixel_values": tensor}
106
+ def _process_text(self, question, answer):
107
+ data_dict = dict(
108
+ conversation=[
109
+ {
110
+ "input": f"{self.DEFAULT_IMAGE_TOKEN}\n{question}",
111
+ "output": answer,
112
+ }
113
+ ]
114
+ )
115
+ data_dict.update(self.template_map_fn(data_dict))
116
+ data_dict.update(
117
+ encode_fn(
118
+ example=data_dict,
119
+ tokenizer=self.tokenizer,
120
+ max_length=self.max_length,
121
+ image_length=self.image_token_repeat,
122
+ input_ids_with_output=True,
123
+ with_image_token=True,
124
+ truncation='right',
125
+ image_token_idx=self.image_token_idx,
126
+ image_token_str=self.DEFAULT_IMAGE_TOKEN,
127
+ )
128
+ )
129
+
130
+ # assert (
131
+ # torch.tensor(data_dict["input_ids"]).long() == self.image_token_idx
132
+ # ).sum() == self.image_length, "Error in image format"
133
+
134
+ data_dict["type"] = "image2text"
135
+ return data_dict
136
+
137
+ def _retry(self):
138
+ return self.__getitem__(random.choice(range(self.__len__())))
139
+
140
+ def __getitem__(self, idx):
141
+ try:
142
+ data_sample = self.data_list[idx]
143
+ image = self._read_image(data_sample["image"]).convert("RGB")
144
+ data = self._process_image(image)
145
+ del image
146
+ question = (
147
+ data_sample[self.key_value][0]["value"]
148
+ .replace("<image>", "")
149
+ .strip()
150
+ )
151
+ answer = (
152
+ data_sample[self.key_value][1]["value"]
153
+ .replace("<image>", "")
154
+ .strip()
155
+ )
156
+
157
+ data.update(self._process_text(question, answer))
158
+
159
+ data.update(image_file=data_sample["image"])
160
+
161
+ return data
162
+
163
+ except Exception as e:
164
+ print(
165
+ f"Error when reading data_sample:{data_sample},{self.data_path}:{data_sample['image']}: {e}",
166
+ flush=True,
167
+ )
168
+ return self._retry()
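The `image_token_repeat` used by several datasets above is a 16-pixel patch grid plus 64 extra tokens; for `image_size = 512` that gives `32 * 32 + 64 = 1088`, matching the `img_tokens != 1088` check in `VqaDataset`. A one-function sketch of the same arithmetic:

# Expected number of <image> placeholder tokens for a given input resolution.
def image_token_repeat(image_size: int, patch: int = 16, extra: int = 64) -> int:
    m = n = image_size // patch
    return m * n + extra

assert image_token_repeat(512) == 1088   # the count VqaDataset validates against
print(image_token_repeat(256))           # 16 * 16 + 64 = 320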
src/datasets/utils.py ADDED
@@ -0,0 +1,303 @@
1
+ import copy
2
+ import random
3
+ from xtuner.dataset.utils import get_bos_eos_token_ids
4
+ from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
5
+ import json
6
+
7
+
8
+ # def crop2square(pil_img):
9
+ # width, height = pil_img.width, pil_img.height
10
+
11
+ # if width > height:
12
+ # y0, y1 = 0, height
13
+ # x0 = random.randint(0, width - height) # [0, w - h]
14
+ # x1 = x0 + height # [h, w]
15
+ # else:
16
+ # x0, x1 = 0, width
17
+ # y0 = random.randint(0, height - width) # [0, h - w]
18
+ # y1 = y0 + width # [w, h]
19
+
20
+ # return pil_img.crop(box=(x0, y0, x1, y1))
21
+
22
+ def crop2square(pil_img):
23
+ width, height = pil_img.width, pil_img.height
24
+ short = min(width, height)
25
+ left = (width - short) // 2
26
+ upper = (height - short) // 2
27
+ return pil_img.crop((left, upper, left + short, upper + short))
28
+ def load_jsonl(json_file):
29
+ with open(json_file) as f:
30
+ lines = f.readlines()
31
+ data = []
32
+ for line in lines:
33
+ data.append(json.loads(line))
34
+ return data
35
+
36
+
37
+ def encode_fn_original(example,
38
+ tokenizer,
39
+ max_length=None,
40
+ image_length=1,
41
+ input_ids_with_output=True,
42
+ with_image_token=False,
43
+ truncation='right',
44
+ image_token_idx=None,
45
+ image_token_str="<image>"):
46
+ """We only support the following three scenarios:
47
+
48
+ 1. Incremental pretraining dataset.
49
+ example['conversation'] = [
50
+ {
51
+ 'input': '',
52
+ 'output': '### Human: Can you write xxx'
53
+ }
54
+ ]
55
+
56
+ 2. Single-turn conversation dataset.
57
+ example['conversation'] = [
58
+ {
59
+ 'input': 'Give three tips for staying healthy.',
60
+ 'output': '1.Eat a balanced diet xxx'
61
+ }
62
+ ]
63
+
64
+ 3. Multi-turn conversation dataset.
65
+ example['conversation'] = [
66
+ {
67
+ 'input': 'Give three tips for staying healthy.',
68
+ 'output': '1.Eat a balanced diet xxx'
69
+ },
70
+ {
71
+ 'input': 'Please expand on the second point.',
72
+ 'output': 'Here is an expanded explanation of the xxx'
73
+ }
74
+ ]
75
+ """
76
+ bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
77
+ if image_token_idx is None: # if not provided, fall back to the tokenizer's <image> token id
78
+ image_token_idx = tokenizer.convert_tokens_to_ids("<image>")
79
+
80
+ is_multi_turn_conversation = len(example['conversation']) > 1
81
+ if is_multi_turn_conversation:
82
+ assert input_ids_with_output
83
+
84
+ input_ids, labels = [], []
85
+ next_needs_bos_token = True
86
+ for single_turn_conversation in example['conversation']:
87
+ input = single_turn_conversation['input']
88
+ if image_token_str in input and with_image_token:
89
+ chunk_encode = [
90
+ tokenizer.encode(chunk, add_special_tokens=False)
91
+ for chunk in input.split(image_token_str)
92
+ ]
93
+ assert len(chunk_encode) == 2
94
+ input_encode = []
95
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
96
+ input_encode.extend(cur_chunk_encode)
97
+ if idx != len(chunk_encode) - 1:
98
+ # input_encode.append(IMAGE_TOKEN_INDEX)
99
+ input_encode += [image_token_idx] * image_length
100
+
101
+ else:
102
+ input_encode = tokenizer.encode(input, add_special_tokens=False)
103
+ if next_needs_bos_token:
104
+ input_ids += bos_token_id
105
+ labels += [IGNORE_INDEX] * len(bos_token_id)
106
+ input_ids += input_encode
107
+ labels += [IGNORE_INDEX] * len(input_encode)
108
+ if input_ids_with_output and 'output' in single_turn_conversation:
109
+ # Add output
110
+ output_with_loss = single_turn_conversation.get(
111
+ 'output_with_loss', True)
112
+ output = single_turn_conversation['output']
113
+
114
+ if image_token_str in output and with_image_token:
115
+ chunk_encode = [
116
+ tokenizer.encode(chunk, add_special_tokens=False)
117
+ for chunk in output.split(image_token_str)
118
+ ]
119
+ assert len(chunk_encode) == 2
120
+ output_encode = []
121
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
122
+ output_encode.extend(cur_chunk_encode)
123
+ if idx != len(chunk_encode) - 1:
124
+ output_encode += [image_token_idx] * image_length
125
+ else:
126
+ output_encode = tokenizer.encode(output, add_special_tokens=False)
127
+ # output_encode = tokenizer.encode(output, add_special_tokens=False)
128
+ input_ids += output_encode
129
+ if output_with_loss:
130
+ labels += copy.deepcopy(output_encode)
131
+ else:
132
+ labels += [IGNORE_INDEX] * len(output_encode)
133
+ # Add EOS_TOKEN (with loss)
134
+ if single_turn_conversation.get('need_eos_token', True):
135
+ next_needs_bos_token = True
136
+ input_ids += eos_token_id
137
+ if output_with_loss:
138
+ labels += copy.deepcopy(eos_token_id)
139
+ else:
140
+ labels += [IGNORE_INDEX] * len(eos_token_id)
141
+ else:
142
+ next_needs_bos_token = False
143
+ # Add SEP (without loss)
144
+ sep = single_turn_conversation.get('sep', '')
145
+ if sep != '':
146
+ sep_encode = tokenizer.encode(sep, add_special_tokens=False)
147
+ input_ids += sep_encode
148
+ labels += [IGNORE_INDEX] * len(sep_encode)
149
+
150
+ if max_length is not None and len(input_ids) > max_length:
151
+ if truncation == 'right':
152
+ input_ids = input_ids[:max_length]
153
+ labels = labels[:max_length]
154
+ elif truncation == 'left':
155
+ input_ids = input_ids[-max_length:]
156
+ labels = labels[-max_length:]
157
+ else:
158
+ assert truncation is None
159
+ return {'input_ids': input_ids, 'labels': labels}
160
+
161
+
162
+
163
+ def encode_fn(
164
+ example,
165
+ tokenizer,
166
+ prompt_template=None,
167
+ max_length=None,
168
+ image_length=1,
169
+ input_ids_with_output=True,
170
+ with_image_token=True,
171
+ truncation='right',
172
+ image_token_idx=None,
173
+ image_token_str="<image>",
174
+ ):
175
+ """
176
+ A versatile encoding function for both image-to-text (conversation) and text-to-image/image-editing tasks.
177
+
178
+ - Image-to-Text: example = {"conversation": [...]}, outputs input_ids + labels.
179
+ - Text-to-Image/Editing: example = str (raw text prompt), outputs input_ids + attention_mask.
180
+ """
181
+ # assert image_token_idx is not None, "Must pass image_token_idx explicitly"
182
+ # print(f"[DEBUG] image_token_idx = {image_token_idx}")
183
+ if image_token_idx is None:
184
+ tokenizer.add_tokens([image_token_str], special_tokens=True)
185
+ image_token_idx = tokenizer.convert_tokens_to_ids(image_token_str)
186
+
187
+ if isinstance(example, str):
188
+ assert prompt_template is not None, \
189
+ "prompt_template 不能为空(text2image/image-editing)"
190
+
191
+ # 1) Build the prompt
192
+ # (the <image> token ids are prepended to the encoded ids below,
193
+ # so only the raw text is wrapped by the template here)
194
+ prompt = f"{example.strip()}"
195
+ # wrap with the prompt template
196
+ prompt = prompt_template["INSTRUCTION"].format(input=prompt)
197
+
198
+ # 2) Encode with the tokenizer (do not let it split <image> as ordinary text)
199
+ # a simple approach: encode without special tokens, then splice the image token ids in manually
200
+ text_ids = tokenizer.encode(
201
+ prompt,
202
+ add_special_tokens=False,
203
+ truncation=True,
204
+ max_length=(max_length - image_length) if max_length else None
205
+ )
206
+ # insert the <image> token ids at the very front (or wherever they are needed)
207
+ input_ids = [image_token_idx] * image_length + text_ids
208
+
209
+ # 3) Truncate if the sequence exceeds max_length
210
+ if max_length is not None and len(input_ids) > max_length:
211
+ input_ids = input_ids[:max_length]
212
+
213
+ # 4) attention_mask
214
+ attention_mask = [1] * len(input_ids)
215
+
216
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
217
+
218
+ # --- Image-to-text task: multi-turn conversation structure ---
219
+ assert isinstance(example, dict) and "conversation" in example
220
+ bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
221
+ is_multi_turn = len(example["conversation"]) > 1
222
+ if is_multi_turn:
223
+ assert input_ids_with_output
224
+
225
+ input_ids, labels = [], []
226
+ next_needs_bos_token = True
227
+
228
+ for single_turn in example["conversation"]:
229
+ input_text = single_turn["input"]
230
+
231
+ # ==== Encode input ====
232
+ if with_image_token and image_token_str in input_text:
233
+ chunks = input_text.split(image_token_str)
234
+ chunk_encoded = [tokenizer.encode(c, add_special_tokens=False) for c in chunks]
235
+ assert len(chunk_encoded) >= 2
236
+ input_encode = []
237
+ for i, chunk in enumerate(chunk_encoded):
238
+ input_encode.extend(chunk)
239
+ if i < len(chunk_encoded) - 1:
240
+ input_encode.extend([image_token_idx] * image_length)
241
+ else:
242
+ input_encode = tokenizer.encode(input_text, add_special_tokens=False)
243
+
244
+ if next_needs_bos_token:
245
+ input_ids.extend(bos_token_id)
246
+ labels.extend([IGNORE_INDEX] * len(bos_token_id))
247
+
248
+ input_ids.extend(input_encode)
249
+ labels.extend([IGNORE_INDEX] * len(input_encode))
250
+
251
+ # ==== Encode output ====
252
+ if input_ids_with_output and "output" in single_turn:
253
+ output = single_turn["output"]
254
+ output_with_loss = single_turn.get("output_with_loss", True)
255
+
256
+ if with_image_token and image_token_str in output:
257
+ chunks = output.split(image_token_str)
258
+ chunk_encoded = [tokenizer.encode(c, add_special_tokens=False) for c in chunks]
259
+ assert len(chunk_encoded) >= 2
260
+ output_encode = []
261
+ for i, chunk in enumerate(chunk_encoded):
262
+ output_encode.extend(chunk)
263
+ if i < len(chunk_encoded) - 1:
264
+ output_encode.extend([image_token_idx] * image_length)
265
+ else:
266
+ output_encode = tokenizer.encode(output, add_special_tokens=False)
267
+
268
+ input_ids.extend(output_encode)
269
+ if output_with_loss:
270
+ labels.extend(output_encode.copy())
271
+ else:
272
+ labels.extend([IGNORE_INDEX] * len(output_encode))
273
+
274
+ # ==== Append EOS ====
275
+ if single_turn.get("need_eos_token", True):
276
+ next_needs_bos_token = True
277
+ input_ids.extend(eos_token_id)
278
+ if output_with_loss:
279
+ labels.extend(eos_token_id.copy())
280
+ else:
281
+ labels.extend([IGNORE_INDEX] * len(eos_token_id))
282
+ else:
283
+ next_needs_bos_token = False
284
+
285
+ # ==== Append separator ====
286
+ sep = single_turn.get("sep", "")
287
+ if sep:
288
+ sep_encoded = tokenizer.encode(sep, add_special_tokens=False)
289
+ input_ids.extend(sep_encoded)
290
+ labels.extend([IGNORE_INDEX] * len(sep_encoded))
291
+
292
+ # ==== Truncation ====
293
+ if max_length is not None and len(input_ids) > max_length:
294
+ if truncation == "right":
295
+ input_ids = input_ids[:max_length]
296
+ labels = labels[:max_length]
297
+ elif truncation == "left":
298
+ input_ids = input_ids[-max_length:]
299
+ labels = labels[-max_length:]
300
+ else:
301
+ raise ValueError("truncation must be 'left', 'right', or None")
302
+
303
+ return {"input_ids": input_ids, "labels": labels}
src/models/mar/decoder.py ADDED
@@ -0,0 +1,102 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+ from timm.models.vision_transformer import Block
5
+ from functools import partial
6
+
7
+
8
+ class MARDecoder(nn.Module):
9
+ """ Masked Autoencoder with VisionTransformer backbone
10
+ """
11
+ def __init__(self, img_size=256, vae_stride=16,
12
+ patch_size=1,
13
+ # encoder_embed_dim=1024,
14
+ decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
15
+ mlp_ratio=4.,
16
+ attn_dropout=0.1,
17
+ proj_dropout=0.1,
18
+ buffer_size=64,
19
+ grad_checkpointing=False,
20
+ ):
21
+ super().__init__()
22
+
23
+ # --------------------------------------------------------------------------
24
+ # VAE
25
+ self.img_size = img_size
26
+ self.vae_stride = vae_stride
27
+
28
+ self.seq_h = self.seq_w = img_size // vae_stride // patch_size
29
+ self.seq_len = self.seq_h * self.seq_w
30
+
31
+ self.grad_checkpointing = grad_checkpointing
32
+
33
+ # --------------------------------------------------------------------------
34
+ # MAR decoder specifics
35
+ self.buffer_size = buffer_size
36
+ # self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
37
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
38
+ self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))
39
+ self.decoder_blocks = nn.ModuleList([
40
+ Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True,
41
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
42
+ proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])
43
+
44
+ self.decoder_norm = nn.LayerNorm(decoder_embed_dim, eps=1e-6)
45
+ self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))
46
+
47
+ self.initialize_weights()
48
+
49
+ def initialize_weights(self):
50
+ # parameters
51
+
52
+ torch.nn.init.normal_(self.mask_token, std=.02)
53
+
54
+ torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
55
+ torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)
56
+
57
+ # initialize nn.Linear and nn.LayerNorm
58
+ self.apply(self._init_weights)
59
+
60
+ def _init_weights(self, m):
61
+ if isinstance(m, nn.Linear):
62
+ # we use xavier_uniform following official JAX ViT:
63
+ torch.nn.init.xavier_uniform_(m.weight)
64
+ if isinstance(m, nn.Linear) and m.bias is not None:
65
+ nn.init.constant_(m.bias, 0)
66
+ elif isinstance(m, nn.LayerNorm):
67
+ if m.bias is not None:
68
+ nn.init.constant_(m.bias, 0)
69
+ if m.weight is not None:
70
+ nn.init.constant_(m.weight, 1.0)
71
+
72
+ def forward(self, x, mask):
73
+
74
+ # x = self.decoder_embed(x)
75
+ mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
76
+
77
+ # pad mask tokens
78
+ mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
79
+ x_after_pad = mask_tokens.clone()
80
+ x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
81
+
82
+ # decoder position embedding
83
+ x = x_after_pad + self.decoder_pos_embed_learned
84
+
85
+ # apply Transformer blocks
86
+ if self.grad_checkpointing and not torch.jit.is_scripting():
87
+ for block in self.decoder_blocks:
88
+ x = checkpoint(block, x)
89
+ else:
90
+ for block in self.decoder_blocks:
91
+ x = block(x)
92
+ x = self.decoder_norm(x)
93
+
94
+ x = x[:, self.buffer_size:]
95
+ x = x + self.diffusion_pos_embed_learned
96
+ return x
97
+
98
+ def gradient_checkpointing_enable(self):
99
+ self.grad_checkpointing = True
100
+
101
+ def gradient_checkpointing_disable(self):
102
+ self.grad_checkpointing = False
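A small shape-check sketch for MARDecoder: embeddings of the kept (unmasked) tokens plus the buffer prefix go in, and a full-length sequence aligned with the diffusion positional embedding comes out. The tiny sizes are illustrative, the import path assumes this upload's layout, and an installed timm whose Block accepts proj_drop/attn_drop is assumed:

import torch
from src.models.mar.decoder import MARDecoder  # path as uploaded here

decoder = MARDecoder(img_size=64, vae_stride=16, patch_size=1,
                     decoder_embed_dim=64, decoder_depth=2,
                     decoder_num_heads=4, buffer_size=8)
B, seq_len, dim = 2, decoder.seq_len, 64  # seq_len = (64 // 16) ** 2 = 16

mask = torch.zeros(B, seq_len)            # 1 = masked, 0 = kept
mask[:, seq_len // 2:] = 1                # mask the second half of each sample

num_kept = int((1 - mask[0]).sum())       # 8 kept tokens per sample
x = torch.randn(B, decoder.buffer_size + num_kept, dim)

out = decoder(x, mask)                    # masked positions filled with mask_token
print(out.shape)                          # torch.Size([2, 16, 64])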
src/models/mar/diffloss.py ADDED
@@ -0,0 +1,249 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+ import math
5
+
6
+ from src.models.mar.diffusion import create_diffusion
7
+
8
+
9
+ class DiffLoss(nn.Module):
10
+ """Diffusion Loss"""
11
+ def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, grad_checkpointing=False):
12
+ super(DiffLoss, self).__init__()
13
+ self.in_channels = target_channels
14
+ self.net = SimpleMLPAdaLN(
15
+ in_channels=target_channels,
16
+ model_channels=width,
17
+ out_channels=target_channels * 2, # for vlb loss
18
+ z_channels=z_channels,
19
+ num_res_blocks=depth,
20
+ grad_checkpointing=grad_checkpointing
21
+ )
22
+
23
+ self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
24
+ self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine")
25
+
26
+ def forward(self, target, z, mask=None):
27
+ t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
28
+ model_kwargs = dict(c=z)
29
+ loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
30
+ loss = loss_dict["loss"]
31
+ if mask is not None:
32
+ loss = (loss * mask).sum() / mask.sum()
33
+ return loss.mean()
34
+
35
+ def sample(self, z, temperature=1.0, cfg=1.0):
36
+ # diffusion loss sampling
37
+ if not cfg == 1.0:
38
+ noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda()
39
+ noise = torch.cat([noise, noise], dim=0)
40
+ model_kwargs = dict(c=z, cfg_scale=cfg)
41
+ sample_fn = self.net.forward_with_cfg
42
+ else:
43
+ noise = torch.randn(z.shape[0], self.in_channels).cuda()
44
+ model_kwargs = dict(c=z)
45
+ sample_fn = self.net.forward
46
+
47
+ sampled_token_latent = self.gen_diffusion.p_sample_loop(
48
+ sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
49
+ temperature=temperature
50
+ )
51
+
52
+ return sampled_token_latent
53
+
54
+
55
+ def modulate(x, shift, scale):
56
+ return x * (1 + scale) + shift
57
+
58
+
59
+ class TimestepEmbedder(nn.Module):
60
+ """
61
+ Embeds scalar timesteps into vector representations.
62
+ """
63
+ def __init__(self, hidden_size, frequency_embedding_size=256):
64
+ super().__init__()
65
+ self.mlp = nn.Sequential(
66
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
67
+ nn.SiLU(),
68
+ nn.Linear(hidden_size, hidden_size, bias=True),
69
+ )
70
+ self.frequency_embedding_size = frequency_embedding_size
71
+
72
+ @staticmethod
73
+ def timestep_embedding(t, dim, max_period=10000):
74
+ """
75
+ Create sinusoidal timestep embeddings.
76
+ :param t: a 1-D Tensor of N indices, one per batch element.
77
+ These may be fractional.
78
+ :param dim: the dimension of the output.
79
+ :param max_period: controls the minimum frequency of the embeddings.
80
+ :return: an (N, D) Tensor of positional embeddings.
81
+ """
82
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
83
+ half = dim // 2
84
+ freqs = torch.exp(
85
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
86
+ ).to(device=t.device)
87
+ args = t[:, None].float() * freqs[None]
88
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
89
+ if dim % 2:
90
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
91
+ return embedding
92
+
93
+ def forward(self, t):
94
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
95
+ t_emb = self.mlp(t_freq.to(self.mlp[0].weight.data.dtype))
96
+ return t_emb
97
+
98
+
99
+ class ResBlock(nn.Module):
100
+ """
101
+ A residual block that can optionally change the number of channels.
102
+ :param channels: the number of input channels.
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ channels
108
+ ):
109
+ super().__init__()
110
+ self.channels = channels
111
+
112
+ self.in_ln = nn.LayerNorm(channels, eps=1e-6)
113
+ self.mlp = nn.Sequential(
114
+ nn.Linear(channels, channels, bias=True),
115
+ nn.SiLU(),
116
+ nn.Linear(channels, channels, bias=True),
117
+ )
118
+
119
+ self.adaLN_modulation = nn.Sequential(
120
+ nn.SiLU(),
121
+ nn.Linear(channels, 3 * channels, bias=True)
122
+ )
123
+
124
+ def forward(self, x, y):
125
+ shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
126
+ h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
127
+ h = self.mlp(h)
128
+ return x + gate_mlp * h
129
+
130
+
131
+ class FinalLayer(nn.Module):
132
+ """
133
+ The final layer adopted from DiT.
134
+ """
135
+ def __init__(self, model_channels, out_channels):
136
+ super().__init__()
137
+ self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
138
+ self.linear = nn.Linear(model_channels, out_channels, bias=True)
139
+ self.adaLN_modulation = nn.Sequential(
140
+ nn.SiLU(),
141
+ nn.Linear(model_channels, 2 * model_channels, bias=True)
142
+ )
143
+
144
+ def forward(self, x, c):
145
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
146
+ x = modulate(self.norm_final(x), shift, scale)
147
+ x = self.linear(x)
148
+ return x
149
+
150
+
151
+ class SimpleMLPAdaLN(nn.Module):
152
+ """
153
+ The MLP for Diffusion Loss.
154
+ :param in_channels: channels in the input Tensor.
155
+ :param model_channels: base channel count for the model.
156
+ :param out_channels: channels in the output Tensor.
157
+ :param z_channels: channels in the condition.
158
+ :param num_res_blocks: number of residual blocks per downsample.
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ in_channels,
164
+ model_channels,
165
+ out_channels,
166
+ z_channels,
167
+ num_res_blocks,
168
+ grad_checkpointing=False
169
+ ):
170
+ super().__init__()
171
+
172
+ self.in_channels = in_channels
173
+ self.model_channels = model_channels
174
+ self.out_channels = out_channels
175
+ self.num_res_blocks = num_res_blocks
176
+ self.grad_checkpointing = grad_checkpointing
177
+
178
+ self.time_embed = TimestepEmbedder(model_channels)
179
+ self.cond_embed = nn.Linear(z_channels, model_channels)
180
+
181
+ self.input_proj = nn.Linear(in_channels, model_channels)
182
+
183
+ res_blocks = []
184
+ for i in range(num_res_blocks):
185
+ res_blocks.append(ResBlock(
186
+ model_channels,
187
+ ))
188
+
189
+ self.res_blocks = nn.ModuleList(res_blocks)
190
+ self.final_layer = FinalLayer(model_channels, out_channels)
191
+
192
+ self.initialize_weights()
193
+
194
+ def initialize_weights(self):
195
+ def _basic_init(module):
196
+ if isinstance(module, nn.Linear):
197
+ torch.nn.init.xavier_uniform_(module.weight)
198
+ if module.bias is not None:
199
+ nn.init.constant_(module.bias, 0)
200
+ self.apply(_basic_init)
201
+
202
+ # Initialize timestep embedding MLP
203
+ nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
204
+ nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
205
+
206
+ # Zero-out adaLN modulation layers
207
+ for block in self.res_blocks:
208
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
209
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
210
+
211
+ # Zero-out output layers
212
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
213
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
214
+ nn.init.constant_(self.final_layer.linear.weight, 0)
215
+ nn.init.constant_(self.final_layer.linear.bias, 0)
216
+
217
+ def forward(self, x, t, c):
218
+ """
219
+ Apply the model to an input batch.
220
+ :param x: an [N x C] Tensor of inputs.
221
+ :param t: a 1-D batch of timesteps.
222
+ :param c: conditioning from AR transformer.
223
+ :return: an [N x C] Tensor of outputs.
224
+ """
225
+ # import pdb; pdb.set_trace()
226
+ x = self.input_proj(x.to(self.input_proj.weight.data.dtype))
227
+ t = self.time_embed(t)
228
+ c = self.cond_embed(c.to(self.cond_embed.weight.data.dtype))
229
+
230
+ y = t + c
231
+
232
+ if self.grad_checkpointing and not torch.jit.is_scripting():
233
+ for block in self.res_blocks:
234
+ x = checkpoint(block, x, y)
235
+ else:
236
+ for block in self.res_blocks:
237
+ x = block(x, y)
238
+
239
+ return self.final_layer(x, y)
240
+
241
+ def forward_with_cfg(self, x, t, c, cfg_scale):
242
+ half = x[: len(x) // 2]
243
+ combined = torch.cat([half, half], dim=0)
244
+ model_out = self.forward(combined, t, c)
245
+ eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
246
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
247
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
248
+ eps = torch.cat([half_eps, half_eps], dim=0)
249
+ return torch.cat([eps, rest], dim=1)
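A CPU-only training-side sketch of DiffLoss: per-token latents are the diffusion target and the conditioning vector z drives SimpleMLPAdaLN through adaLN modulation. The tiny widths and the flattened (tokens, channels) layout are illustrative assumptions, the import path follows this upload's layout, and note that sample() additionally requires CUDA because it allocates noise with .cuda():

import torch
from src.models.mar.diffloss import DiffLoss  # path as uploaded here

diffloss = DiffLoss(target_channels=8, z_channels=16,
                    depth=2, width=32, num_sampling_steps="50")

tokens = 6                             # e.g. batch * masked positions, flattened
target = torch.randn(tokens, 8)        # latents to be denoised
z = torch.randn(tokens, 16)            # conditioning from the AR transformer
mask = torch.ones(tokens)              # weight every token equally

loss = diffloss(target, z, mask=mask)  # scalar diffusion loss
print(loss.item())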
src/models/mar/diffusion/__init__.py ADDED
@@ -0,0 +1,47 @@
1
+ # Adopted from DiT, which is modified from OpenAI's diffusion repos
2
+ # DiT: https://github.com/facebookresearch/DiT/diffusion
3
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
4
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
5
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
6
+
7
+ from . import gaussian_diffusion as gd
8
+ from .respace import SpacedDiffusion, space_timesteps
9
+
10
+
11
+ def create_diffusion(
12
+ timestep_respacing,
13
+ noise_schedule="highres_cosine",
14
+ use_kl=False,
15
+ sigma_small=False,
16
+ predict_xstart=False,
17
+ learn_sigma=True,
18
+ rescale_learned_sigmas=False,
19
+ diffusion_steps=1000
20
+ ):
21
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
22
+ if use_kl:
23
+ loss_type = gd.LossType.RESCALED_KL
24
+ elif rescale_learned_sigmas:
25
+ loss_type = gd.LossType.RESCALED_MSE
26
+ else:
27
+ loss_type = gd.LossType.MSE
28
+ if timestep_respacing is None or timestep_respacing == "":
29
+ timestep_respacing = [diffusion_steps]
30
+ return SpacedDiffusion(
31
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
32
+ betas=betas,
33
+ model_mean_type=(
34
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
35
+ ),
36
+ model_var_type=(
37
+ (
38
+ gd.ModelVarType.FIXED_LARGE
39
+ if not sigma_small
40
+ else gd.ModelVarType.FIXED_SMALL
41
+ )
42
+ if not learn_sigma
43
+ else gd.ModelVarType.LEARNED_RANGE
44
+ ),
45
+ loss_type=loss_type
46
+ # rescale_timesteps=rescale_timesteps,
47
+ )
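create_diffusion is called twice by DiffLoss: with an empty timestep_respacing for the full training chain and with a step-count string for faster sampling. A short sketch (it assumes SpacedDiffusion exposes num_timesteps after respacing, as in the OpenAI reference code this is adapted from, and uses this upload's import path):

from src.models.mar.diffusion import create_diffusion  # path as uploaded here

train_diff = create_diffusion(timestep_respacing="", noise_schedule="cosine")
gen_diff = create_diffusion(timestep_respacing="50", noise_schedule="cosine")

print(train_diff.num_timesteps)  # full chain used for training losses
print(gen_diff.num_timesteps)    # respaced chain used at sampling time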
src/models/mar/diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,73 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import torch as th
7
+ import numpy as np
8
+
9
+
10
+ def normal_kl(mean1, logvar1, mean2, logvar2):
11
+ """
12
+ Compute the KL divergence between two gaussians.
13
+ Shapes are automatically broadcasted, so batches can be compared to
14
+ scalars, among other use cases.
15
+ """
16
+ tensor = None
17
+ for obj in (mean1, logvar1, mean2, logvar2):
18
+ if isinstance(obj, th.Tensor):
19
+ tensor = obj
20
+ break
21
+ assert tensor is not None, "at least one argument must be a Tensor"
22
+
23
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
24
+ # Tensors, but it does not work for th.exp().
25
+ logvar1, logvar2 = [
26
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27
+ for x in (logvar1, logvar2)
28
+ ]
29
+
30
+ return 0.5 * (
31
+ -1.0
32
+ + logvar2
33
+ - logvar1
34
+ + th.exp(logvar1 - logvar2)
35
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36
+ )
37
+
38
+
39
+ def approx_standard_normal_cdf(x):
40
+ """
41
+ A fast approximation of the cumulative distribution function of the
42
+ standard normal.
43
+ """
44
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45
+
46
+
47
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
48
+ """
49
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
50
+ given image.
51
+ :param x: the target images. It is assumed that this was uint8 values,
52
+ rescaled to the range [-1, 1].
53
+ :param means: the Gaussian mean Tensor.
54
+ :param log_scales: the Gaussian log stddev Tensor.
55
+ :return: a tensor like x of log probabilities (in nats).
56
+ """
57
+ assert x.shape == means.shape == log_scales.shape
58
+ centered_x = x - means
59
+ inv_stdv = th.exp(-log_scales)
60
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
61
+ cdf_plus = approx_standard_normal_cdf(plus_in)
62
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
63
+ cdf_min = approx_standard_normal_cdf(min_in)
64
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
65
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
66
+ cdf_delta = cdf_plus - cdf_min
67
+ log_probs = th.where(
68
+ x < -0.999,
69
+ log_cdf_plus,
70
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
71
+ )
72
+ assert log_probs.shape == x.shape
73
+ return log_probs
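A quick numeric check of normal_kl against the closed form, KL(N(1, 1) || N(0, 1)) = 0.5 nats; at least one argument must be a tensor so the helper can infer device and dtype (import path assumed from this upload's layout):

import torch as th
from src.models.mar.diffusion.diffusion_utils import normal_kl  # path as uploaded here

kl = normal_kl(th.tensor(1.0), th.tensor(0.0), 0.0, 0.0)  # means 1 vs 0, unit variances
print(kl)  # tensor(0.5000)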
src/models/mar/diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,884 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch as th
11
+ import enum
12
+
13
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
+
15
+
16
+ def mean_flat(tensor):
17
+ """
18
+ Take the mean over all non-batch dimensions.
19
+ """
20
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
21
+
22
+
23
+ class ModelMeanType(enum.Enum):
24
+ """
25
+ Which type of output the model predicts.
26
+ """
27
+
28
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29
+ START_X = enum.auto() # the model predicts x_0
30
+ EPSILON = enum.auto() # the model predicts epsilon
31
+
32
+
33
+ class ModelVarType(enum.Enum):
34
+ """
35
+ What is used as the model's output variance.
36
+ The LEARNED_RANGE option has been added to allow the model to predict
37
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38
+ """
39
+
40
+ LEARNED = enum.auto()
41
+ FIXED_SMALL = enum.auto()
42
+ FIXED_LARGE = enum.auto()
43
+ LEARNED_RANGE = enum.auto()
44
+
45
+
46
+ class LossType(enum.Enum):
47
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48
+ RESCALED_MSE = (
49
+ enum.auto()
50
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
51
+ KL = enum.auto() # use the variational lower-bound
52
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
53
+
54
+ def is_vb(self):
55
+ return self == LossType.KL or self == LossType.RESCALED_KL
56
+
57
+
58
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
59
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
60
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
61
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
62
+ return betas
63
+
64
+
65
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
66
+ """
67
+ This is the deprecated API for creating beta schedules.
68
+ See get_named_beta_schedule() for the new library of schedules.
69
+ """
70
+ if beta_schedule == "quad":
71
+ betas = (
72
+ np.linspace(
73
+ beta_start ** 0.5,
74
+ beta_end ** 0.5,
75
+ num_diffusion_timesteps,
76
+ dtype=np.float64,
77
+ )
78
+ ** 2
79
+ )
80
+ elif beta_schedule == "linear":
81
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
82
+ elif beta_schedule == "warmup10":
83
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
84
+ elif beta_schedule == "warmup50":
85
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
86
+ elif beta_schedule == "const":
87
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
88
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
89
+ betas = 1.0 / np.linspace(
90
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
91
+ )
92
+ else:
93
+ raise NotImplementedError(beta_schedule)
94
+ assert betas.shape == (num_diffusion_timesteps,)
95
+ return betas
96
+
97
+
98
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
99
+ """
100
+ Get a pre-defined beta schedule for the given name.
101
+ The beta schedule library consists of beta schedules which remain similar
102
+ in the limit of num_diffusion_timesteps.
103
+ Beta schedules may be added, but should not be removed or changed once
104
+ they are committed to maintain backwards compatibility.
105
+ """
106
+ if schedule_name == "linear":
107
+ # Linear schedule from Ho et al, extended to work for any number of
108
+ # diffusion steps.
109
+ scale = 1000 / num_diffusion_timesteps
110
+ return get_beta_schedule(
111
+ "linear",
112
+ beta_start=scale * 0.0001,
113
+ beta_end=scale * 0.02,
114
+ num_diffusion_timesteps=num_diffusion_timesteps,
115
+ )
116
+ elif schedule_name == "cosine":
117
+ return betas_for_alpha_bar(
118
+ num_diffusion_timesteps,
119
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
120
+ )
121
+ elif schedule_name == "highres_cosine":
122
+ # Custom smoother cosine schedule for high-resolution diffusion
123
+ return betas_for_alpha_bar(
124
+ num_diffusion_timesteps,
125
+ lambda t: math.cos((t + 0.005) / 1.005 * math.pi / 2) ** 2,
126
+ max_beta=0.2, # conservative to avoid over-noising at high-res
127
+ )
128
+ else:
129
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
130
+
131
+
132
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.2):
133
+ """
134
+ Create a beta schedule that discretizes the given alpha_t_bar function,
135
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
136
+ :param num_diffusion_timesteps: the number of betas to produce.
137
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
138
+ produces the cumulative product of (1-beta) up to that
139
+ part of the diffusion process.
140
+ :param max_beta: the maximum beta to use; use values lower than 1 to
141
+ prevent singularities.
142
+ """
143
+ betas = []
144
+ for i in range(num_diffusion_timesteps):
145
+ t1 = i / num_diffusion_timesteps
146
+ t2 = (i + 1) / num_diffusion_timesteps
147
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
148
+ return np.array(betas)
149
+
150
+
151
+ class GaussianDiffusion:
152
+ """
153
+ Utilities for training and sampling diffusion models.
154
+ Original ported from this codebase:
155
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
156
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
157
+ starting at T and going to 1.
158
+ """
159
+
160
+ def __init__(
161
+ self,
162
+ *,
163
+ betas,
164
+ model_mean_type,
165
+ model_var_type,
166
+ loss_type
167
+ ):
168
+
169
+ self.model_mean_type = model_mean_type
170
+ self.model_var_type = model_var_type
171
+ self.loss_type = loss_type
172
+
173
+ # Use float64 for accuracy.
174
+ betas = np.array(betas, dtype=np.float64)
175
+ self.betas = betas
176
+ assert len(betas.shape) == 1, "betas must be 1-D"
177
+ assert (betas > 0).all() and (betas <= 1).all()
178
+
179
+ self.num_timesteps = int(betas.shape[0])
180
+
181
+ alphas = 1.0 - betas
182
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
183
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
184
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
185
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
186
+
187
+ # calculations for diffusion q(x_t | x_{t-1}) and others
188
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
189
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
190
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
191
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
192
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
193
+
194
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
195
+ self.posterior_variance = (
196
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
197
+ )
198
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
199
+ self.posterior_log_variance_clipped = np.log(
200
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
201
+ ) if len(self.posterior_variance) > 1 else np.array([])
202
+
203
+ self.posterior_mean_coef1 = (
204
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
205
+ )
206
+ self.posterior_mean_coef2 = (
207
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
208
+ )
209
+
210
+ def q_mean_variance(self, x_start, t):
211
+ """
212
+ Get the distribution q(x_t | x_0).
213
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
214
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
215
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
216
+ """
217
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
218
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
219
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
220
+ return mean, variance, log_variance
221
+
222
+ def q_sample(self, x_start, t, noise=None):
223
+ """
224
+ Diffuse the data for a given number of diffusion steps.
225
+ In other words, sample from q(x_t | x_0).
226
+ :param x_start: the initial data batch.
227
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
228
+ :param noise: if specified, the split-out normal noise.
229
+ :return: A noisy version of x_start.
230
+ """
231
+ if noise is None:
232
+ noise = th.randn_like(x_start)
233
+ assert noise.shape == x_start.shape
234
+ return (
235
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
236
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
237
+ )
238
+
239
+ def q_posterior_mean_variance(self, x_start, x_t, t):
240
+ """
241
+ Compute the mean and variance of the diffusion posterior:
242
+ q(x_{t-1} | x_t, x_0)
243
+ """
244
+ assert x_start.shape == x_t.shape
245
+ posterior_mean = (
246
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
247
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
248
+ )
249
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
250
+ posterior_log_variance_clipped = _extract_into_tensor(
251
+ self.posterior_log_variance_clipped, t, x_t.shape
252
+ )
253
+ assert (
254
+ posterior_mean.shape[0]
255
+ == posterior_variance.shape[0]
256
+ == posterior_log_variance_clipped.shape[0]
257
+ == x_start.shape[0]
258
+ )
259
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
260
+
261
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
262
+ """
263
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
264
+ the initial x, x_0.
265
+ :param model: the model, which takes a signal and a batch of timesteps
266
+ as input.
267
+ :param x: the [N x C x ...] tensor at time t.
268
+ :param t: a 1-D Tensor of timesteps.
269
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
270
+ :param denoised_fn: if not None, a function which applies to the
271
+ x_start prediction before it is used to sample. Applies before
272
+ clip_denoised.
273
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
274
+ pass to the model. This can be used for conditioning.
275
+ :return: a dict with the following keys:
276
+ - 'mean': the model mean output.
277
+ - 'variance': the model variance output.
278
+ - 'log_variance': the log of 'variance'.
279
+ - 'pred_xstart': the prediction for x_0.
280
+ """
281
+ if model_kwargs is None:
282
+ model_kwargs = {}
283
+
284
+ B, C = x.shape[:2]
285
+ assert t.shape == (B,)
286
+ model_output = model(x, t, **model_kwargs)
287
+ if isinstance(model_output, tuple):
288
+ model_output, extra = model_output
289
+ else:
290
+ extra = None
291
+
292
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
293
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
294
+ model_output, model_var_values = th.split(model_output, C, dim=1)
295
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
296
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
297
+ # The model_var_values is [-1, 1] for [min_var, max_var].
298
+ frac = (model_var_values + 1) / 2
299
+ model_log_variance = frac * max_log + (1 - frac) * min_log
300
+ model_variance = th.exp(model_log_variance)
301
+ else:
302
+ model_variance, model_log_variance = {
303
+ # for fixedlarge, we set the initial (log-)variance like so
304
+ # to get a better decoder log likelihood.
305
+ ModelVarType.FIXED_LARGE: (
306
+ np.append(self.posterior_variance[1], self.betas[1:]),
307
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
308
+ ),
309
+ ModelVarType.FIXED_SMALL: (
310
+ self.posterior_variance,
311
+ self.posterior_log_variance_clipped,
312
+ ),
313
+ }[self.model_var_type]
314
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
315
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
316
+
317
+ def process_xstart(x):
318
+ if denoised_fn is not None:
319
+ x = denoised_fn(x)
320
+ if clip_denoised:
321
+ return x.clamp(-1, 1)
322
+ return x
323
+
324
+ if self.model_mean_type == ModelMeanType.START_X:
325
+ pred_xstart = process_xstart(model_output)
326
+ else:
327
+ pred_xstart = process_xstart(
328
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
329
+ )
330
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
331
+
332
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
333
+ return {
334
+ "mean": model_mean,
335
+ "variance": model_variance,
336
+ "log_variance": model_log_variance,
337
+ "pred_xstart": pred_xstart,
338
+ "extra": extra,
339
+ }
340
+
341
+ def _predict_xstart_from_eps(self, x_t, t, eps):
342
+ assert x_t.shape == eps.shape
343
+ return (
344
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
345
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
346
+ )
347
+
348
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
349
+ return (
350
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
351
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
352
+
353
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
354
+ """
355
+ Compute the mean for the previous step, given a function cond_fn that
356
+ computes the gradient of a conditional log probability with respect to
357
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
358
+ condition on y.
359
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
360
+ """
361
+ gradient = cond_fn(x, t, **model_kwargs)
362
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
363
+ return new_mean
364
+
365
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
366
+ """
367
+ Compute what the p_mean_variance output would have been, should the
368
+ model's score function be conditioned by cond_fn.
369
+ See condition_mean() for details on cond_fn.
370
+ Unlike condition_mean(), this instead uses the conditioning strategy
371
+ from Song et al (2020).
372
+ """
373
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
374
+
375
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
376
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
377
+
378
+ out = p_mean_var.copy()
379
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
380
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
381
+ return out
382
+
383
+ def p_sample(
384
+ self,
385
+ model,
386
+ x,
387
+ t,
388
+ clip_denoised=True,
389
+ denoised_fn=None,
390
+ cond_fn=None,
391
+ model_kwargs=None,
392
+ temperature=1.0
393
+ ):
394
+ """
395
+ Sample x_{t-1} from the model at the given timestep.
396
+ :param model: the model to sample from.
397
+ :param x: the current tensor at x_{t-1}.
398
+ :param t: the value of t, starting at 0 for the first diffusion step.
399
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
400
+ :param denoised_fn: if not None, a function which applies to the
401
+ x_start prediction before it is used to sample.
402
+ :param cond_fn: if not None, this is a gradient function that acts
403
+ similarly to the model.
404
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
405
+ pass to the model. This can be used for conditioning.
406
+ :param temperature: temperature scaling during Diff Loss sampling.
407
+ :return: a dict containing the following keys:
408
+ - 'sample': a random sample from the model.
409
+ - 'pred_xstart': a prediction of x_0.
410
+ """
411
+ out = self.p_mean_variance(
412
+ model,
413
+ x,
414
+ t,
415
+ clip_denoised=clip_denoised,
416
+ denoised_fn=denoised_fn,
417
+ model_kwargs=model_kwargs,
418
+ )
419
+ noise = th.randn_like(x)
420
+ nonzero_mask = (
421
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
422
+ ) # no noise when t == 0
423
+ if cond_fn is not None:
424
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
425
+ # scale the noise by temperature
426
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise * temperature
427
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
428
+
429
+ def p_sample_loop(
430
+ self,
431
+ model,
432
+ shape,
433
+ noise=None,
434
+ clip_denoised=True,
435
+ denoised_fn=None,
436
+ cond_fn=None,
437
+ model_kwargs=None,
438
+ device=None,
439
+ progress=False,
440
+ temperature=1.0,
441
+ ):
442
+ """
443
+ Generate samples from the model.
444
+ :param model: the model module.
445
+ :param shape: the shape of the samples, (N, C, H, W).
446
+ :param noise: if specified, the noise from the encoder to sample.
447
+ Should be of the same shape as `shape`.
448
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
449
+ :param denoised_fn: if not None, a function which applies to the
450
+ x_start prediction before it is used to sample.
451
+ :param cond_fn: if not None, this is a gradient function that acts
452
+ similarly to the model.
453
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
454
+ pass to the model. This can be used for conditioning.
455
+ :param device: if specified, the device to create the samples on.
456
+ If not specified, use a model parameter's device.
457
+ :param progress: if True, show a tqdm progress bar.
458
+ :param temperature: temperature scaling during Diff Loss sampling.
459
+ :return: a non-differentiable batch of samples.
460
+ """
461
+ final = None
462
+ for sample in self.p_sample_loop_progressive(
463
+ model,
464
+ shape,
465
+ noise=noise,
466
+ clip_denoised=clip_denoised,
467
+ denoised_fn=denoised_fn,
468
+ cond_fn=cond_fn,
469
+ model_kwargs=model_kwargs,
470
+ device=device,
471
+ progress=progress,
472
+ temperature=temperature,
473
+ ):
474
+ final = sample
475
+ return final["sample"]
476
+
477
+ def p_sample_loop_progressive(
478
+ self,
479
+ model,
480
+ shape,
481
+ noise=None,
482
+ clip_denoised=True,
483
+ denoised_fn=None,
484
+ cond_fn=None,
485
+ model_kwargs=None,
486
+ device=None,
487
+ progress=False,
488
+ temperature=1.0,
489
+ ):
490
+ """
491
+ Generate samples from the model and yield intermediate samples from
492
+ each timestep of diffusion.
493
+ Arguments are the same as p_sample_loop().
494
+ Returns a generator over dicts, where each dict is the return value of
495
+ p_sample().
496
+ """
497
+ assert isinstance(shape, (tuple, list))
498
+ if noise is not None:
499
+ img = noise
500
+ else:
501
+ img = th.randn(*shape).cuda()
502
+ indices = list(range(self.num_timesteps))[::-1]
503
+
504
+ if progress:
505
+ # Lazy import so that we don't depend on tqdm.
506
+ from tqdm.auto import tqdm
507
+
508
+ indices = tqdm(indices)
509
+
510
+ for i in indices:
511
+ t = th.tensor([i] * shape[0]).cuda()
512
+ with th.no_grad():
513
+ out = self.p_sample(
514
+ model,
515
+ img,
516
+ t,
517
+ clip_denoised=clip_denoised,
518
+ denoised_fn=denoised_fn,
519
+ cond_fn=cond_fn,
520
+ model_kwargs=model_kwargs,
521
+ temperature=temperature,
522
+ )
523
+ yield out
524
+ img = out["sample"]
525
+
526
+ def ddim_sample(
527
+ self,
528
+ model,
529
+ x,
530
+ t,
531
+ clip_denoised=True,
532
+ denoised_fn=None,
533
+ cond_fn=None,
534
+ model_kwargs=None,
535
+ eta=0.0,
536
+ ):
537
+ """
538
+ Sample x_{t-1} from the model using DDIM.
539
+ Same usage as p_sample().
540
+ """
541
+ out = self.p_mean_variance(
542
+ model,
543
+ x,
544
+ t,
545
+ clip_denoised=clip_denoised,
546
+ denoised_fn=denoised_fn,
547
+ model_kwargs=model_kwargs,
548
+ )
549
+ if cond_fn is not None:
550
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
551
+
552
+ # Usually our model outputs epsilon, but we re-derive it
553
+ # in case we used x_start or x_prev prediction.
554
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
555
+
556
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
557
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
558
+ sigma = (
559
+ eta
560
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
561
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
562
+ )
563
+ # Equation 12.
564
+ noise = th.randn_like(x)
565
+ mean_pred = (
566
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
567
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
568
+ )
569
+ nonzero_mask = (
570
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
571
+ ) # no noise when t == 0
572
+ sample = mean_pred + nonzero_mask * sigma * noise
573
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
574
+
575
+ def ddim_reverse_sample(
576
+ self,
577
+ model,
578
+ x,
579
+ t,
580
+ clip_denoised=True,
581
+ denoised_fn=None,
582
+ cond_fn=None,
583
+ model_kwargs=None,
584
+ eta=0.0,
585
+ ):
586
+ """
587
+ Sample x_{t+1} from the model using DDIM reverse ODE.
588
+ """
589
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
590
+ out = self.p_mean_variance(
591
+ model,
592
+ x,
593
+ t,
594
+ clip_denoised=clip_denoised,
595
+ denoised_fn=denoised_fn,
596
+ model_kwargs=model_kwargs,
597
+ )
598
+ if cond_fn is not None:
599
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
600
+ # Usually our model outputs epsilon, but we re-derive it
601
+ # in case we used x_start or x_prev prediction.
602
+ eps = (
603
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
604
+ - out["pred_xstart"]
605
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
606
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
607
+
608
+ # Equation 12. reversed
609
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
610
+
611
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
612
+
613
+ def ddim_sample_loop(
614
+ self,
615
+ model,
616
+ shape,
617
+ noise=None,
618
+ clip_denoised=True,
619
+ denoised_fn=None,
620
+ cond_fn=None,
621
+ model_kwargs=None,
622
+ device=None,
623
+ progress=False,
624
+ eta=0.0,
625
+ ):
626
+ """
627
+ Generate samples from the model using DDIM.
628
+ Same usage as p_sample_loop().
629
+ """
630
+ final = None
631
+ for sample in self.ddim_sample_loop_progressive(
632
+ model,
633
+ shape,
634
+ noise=noise,
635
+ clip_denoised=clip_denoised,
636
+ denoised_fn=denoised_fn,
637
+ cond_fn=cond_fn,
638
+ model_kwargs=model_kwargs,
639
+ device=device,
640
+ progress=progress,
641
+ eta=eta,
642
+ ):
643
+ final = sample
644
+ return final["sample"]
645
+
646
+ def ddim_sample_loop_progressive(
647
+ self,
648
+ model,
649
+ shape,
650
+ noise=None,
651
+ clip_denoised=True,
652
+ denoised_fn=None,
653
+ cond_fn=None,
654
+ model_kwargs=None,
655
+ device=None,
656
+ progress=False,
657
+ eta=0.0,
658
+ ):
659
+ """
660
+ Use DDIM to sample from the model and yield intermediate samples from
661
+ each timestep of DDIM.
662
+ Same usage as p_sample_loop_progressive().
663
+ """
664
+ assert isinstance(shape, (tuple, list))
665
+ if noise is not None:
666
+ img = noise
667
+ else:
668
+ img = th.randn(*shape).cuda()
669
+ indices = list(range(self.num_timesteps))[::-1]
670
+
671
+ if progress:
672
+ # Lazy import so that we don't depend on tqdm.
673
+ from tqdm.auto import tqdm
674
+
675
+ indices = tqdm(indices)
676
+
677
+ for i in indices:
678
+ t = th.tensor([i] * shape[0]).cuda()
679
+ with th.no_grad():
680
+ out = self.ddim_sample(
681
+ model,
682
+ img,
683
+ t,
684
+ clip_denoised=clip_denoised,
685
+ denoised_fn=denoised_fn,
686
+ cond_fn=cond_fn,
687
+ model_kwargs=model_kwargs,
688
+ eta=eta,
689
+ )
690
+ yield out
691
+ img = out["sample"]
692
+
693
+ def _vb_terms_bpd(
694
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
695
+ ):
696
+ """
697
+ Get a term for the variational lower-bound.
698
+ The resulting units are bits (rather than nats, as one might expect).
699
+ This allows for comparison to other papers.
700
+ :return: a dict with the following keys:
701
+ - 'output': a shape [N] tensor of NLLs or KLs.
702
+ - 'pred_xstart': the x_0 predictions.
703
+ """
704
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
705
+ x_start=x_start, x_t=x_t, t=t
706
+ )
707
+ out = self.p_mean_variance(
708
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
709
+ )
710
+ kl = normal_kl(
711
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
712
+ )
713
+ kl = mean_flat(kl) / np.log(2.0)
714
+
715
+ decoder_nll = -discretized_gaussian_log_likelihood(
716
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
717
+ )
718
+ assert decoder_nll.shape == x_start.shape
719
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
720
+
721
+ # At the first timestep return the decoder NLL,
722
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
723
+ output = th.where((t == 0), decoder_nll, kl)
724
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
725
+
726
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
727
+ """
728
+ Compute training losses for a single timestep.
729
+ :param model: the model to evaluate loss on.
730
+ :param x_start: the [N x C x ...] tensor of inputs.
731
+ :param t: a batch of timestep indices.
732
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
733
+ pass to the model. This can be used for conditioning.
734
+ :param noise: if specified, the specific Gaussian noise to try to remove.
735
+ :return: a dict with the key "loss" containing a tensor of shape [N].
736
+ Some mean or variance settings may also have other keys.
737
+ """
738
+ if model_kwargs is None:
739
+ model_kwargs = {}
740
+ if noise is None:
741
+ noise = th.randn_like(x_start)
742
+ x_t = self.q_sample(x_start, t, noise=noise)
743
+
744
+ terms = {}
745
+
746
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
747
+ terms["loss"] = self._vb_terms_bpd(
748
+ model=model,
749
+ x_start=x_start,
750
+ x_t=x_t,
751
+ t=t,
752
+ clip_denoised=False,
753
+ model_kwargs=model_kwargs,
754
+ )["output"]
755
+ if self.loss_type == LossType.RESCALED_KL:
756
+ terms["loss"] *= self.num_timesteps
757
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
758
+ model_output = model(x_t, t, **model_kwargs)
759
+
760
+ if self.model_var_type in [
761
+ ModelVarType.LEARNED,
762
+ ModelVarType.LEARNED_RANGE,
763
+ ]:
764
+ B, C = x_t.shape[:2]
765
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
766
+ model_output, model_var_values = th.split(model_output, C, dim=1)
767
+ # Learn the variance using the variational bound, but don't let
768
+ # it affect our mean prediction.
769
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
770
+ terms["vb"] = self._vb_terms_bpd(
771
+ model=lambda *args, r=frozen_out: r,
772
+ x_start=x_start,
773
+ x_t=x_t,
774
+ t=t,
775
+ clip_denoised=False,
776
+ )["output"]
777
+ if self.loss_type == LossType.RESCALED_MSE:
778
+ # Divide by 1000 for equivalence with initial implementation.
779
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
780
+ terms["vb"] *= self.num_timesteps / 1000.0
781
+
782
+ target = {
783
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
784
+ x_start=x_start, x_t=x_t, t=t
785
+ )[0],
786
+ ModelMeanType.START_X: x_start,
787
+ ModelMeanType.EPSILON: noise,
788
+ }[self.model_mean_type]
789
+ assert model_output.shape == target.shape == x_start.shape
790
+ terms["mse"] = mean_flat((target - model_output) ** 2)
791
+ if "vb" in terms:
792
+ terms["loss"] = terms["mse"] + terms["vb"]
793
+ else:
794
+ terms["loss"] = terms["mse"]
795
+ else:
796
+ raise NotImplementedError(self.loss_type)
797
+
798
+ return terms
799
+
800
+ def _prior_bpd(self, x_start):
801
+ """
802
+ Get the prior KL term for the variational lower-bound, measured in
803
+ bits-per-dim.
804
+ This term can't be optimized, as it only depends on the encoder.
805
+ :param x_start: the [N x C x ...] tensor of inputs.
806
+ :return: a batch of [N] KL values (in bits), one per batch element.
807
+ """
808
+ batch_size = x_start.shape[0]
809
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
810
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
811
+ kl_prior = normal_kl(
812
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
813
+ )
814
+ return mean_flat(kl_prior) / np.log(2.0)
815
+
816
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
817
+ """
818
+ Compute the entire variational lower-bound, measured in bits-per-dim,
819
+ as well as other related quantities.
820
+ :param model: the model to evaluate loss on.
821
+ :param x_start: the [N x C x ...] tensor of inputs.
822
+ :param clip_denoised: if True, clip denoised samples.
823
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
824
+ pass to the model. This can be used for conditioning.
825
+ :return: a dict containing the following keys:
826
+ - total_bpd: the total variational lower-bound, per batch element.
827
+ - prior_bpd: the prior term in the lower-bound.
828
+ - vb: an [N x T] tensor of terms in the lower-bound.
829
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
830
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
831
+ """
832
+ device = x_start.device
833
+ batch_size = x_start.shape[0]
834
+
835
+ vb = []
836
+ xstart_mse = []
837
+ mse = []
838
+ for t in list(range(self.num_timesteps))[::-1]:
839
+ t_batch = th.tensor([t] * batch_size, device=device)
840
+ noise = th.randn_like(x_start)
841
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
842
+ # Calculate VLB term at the current timestep
843
+ with th.no_grad():
844
+ out = self._vb_terms_bpd(
845
+ model,
846
+ x_start=x_start,
847
+ x_t=x_t,
848
+ t=t_batch,
849
+ clip_denoised=clip_denoised,
850
+ model_kwargs=model_kwargs,
851
+ )
852
+ vb.append(out["output"])
853
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
854
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
855
+ mse.append(mean_flat((eps - noise) ** 2))
856
+
857
+ vb = th.stack(vb, dim=1)
858
+ xstart_mse = th.stack(xstart_mse, dim=1)
859
+ mse = th.stack(mse, dim=1)
860
+
861
+ prior_bpd = self._prior_bpd(x_start)
862
+ total_bpd = vb.sum(dim=1) + prior_bpd
863
+ return {
864
+ "total_bpd": total_bpd,
865
+ "prior_bpd": prior_bpd,
866
+ "vb": vb,
867
+ "xstart_mse": xstart_mse,
868
+ "mse": mse,
869
+ }
870
+
871
+
872
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
873
+ """
874
+ Extract values from a 1-D numpy array for a batch of indices.
875
+ :param arr: the 1-D numpy array.
876
+ :param timesteps: a tensor of indices into the array to extract.
877
+ :param broadcast_shape: a larger shape of K dimensions with the batch
878
+ dimension equal to the length of timesteps.
879
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
880
+ """
881
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
882
+ while len(res.shape) < len(broadcast_shape):
883
+ res = res[..., None]
884
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
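
For reference, a minimal standalone sketch of the gather-and-broadcast pattern that _extract_into_tensor implements above; the linear beta schedule and tensor shapes below are illustrative assumptions, not values taken from this repo.

import numpy as np
import torch as th

def extract_into_tensor(arr, timesteps, broadcast_shape):
    # same logic as _extract_into_tensor above: gather one value per timestep, then broadcast
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + th.zeros(broadcast_shape, device=timesteps.device)

betas = np.linspace(1e-4, 0.02, 1000)            # toy linear schedule (assumption)
alphas_cumprod = np.cumprod(1.0 - betas)
x_start = th.randn(4, 16, 16, 16)                # toy [N, C, H, W] latent batch (assumption)
t = th.randint(0, 1000, (4,))
coef = extract_into_tensor(np.sqrt(alphas_cumprod), t, x_start.shape)
print(coef.shape)                                # torch.Size([4, 16, 16, 16])
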
src/models/mar/diffusion/respace.py ADDED
@@ -0,0 +1,129 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+ from .gaussian_diffusion import GaussianDiffusion
10
+
11
+
12
+ def space_timesteps(num_timesteps, section_counts):
13
+ """
14
+ Create a list of timesteps to use from an original diffusion process,
15
+ given the number of timesteps we want to take from equally-sized portions
16
+ of the original process.
17
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
18
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
19
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
20
+ If the stride is a string starting with "ddim", then the fixed striding
21
+ from the DDIM paper is used, and only one section is allowed.
22
+ :param num_timesteps: the number of diffusion steps in the original
23
+ process to divide up.
24
+ :param section_counts: either a list of numbers, or a string containing
25
+ comma-separated numbers, indicating the step count
26
+ per section. As a special case, use "ddimN" where N
27
+ is a number of steps to use the striding from the
28
+ DDIM paper.
29
+ :return: a set of diffusion steps from the original process to use.
30
+ """
31
+ if isinstance(section_counts, str):
32
+ if section_counts.startswith("ddim"):
33
+ desired_count = int(section_counts[len("ddim") :])
34
+ for i in range(1, num_timesteps):
35
+ if len(range(0, num_timesteps, i)) == desired_count:
36
+ return set(range(0, num_timesteps, i))
37
+ raise ValueError(
38
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
39
+ )
40
+ section_counts = [int(x) for x in section_counts.split(",")]
41
+ size_per = num_timesteps // len(section_counts)
42
+ extra = num_timesteps % len(section_counts)
43
+ start_idx = 0
44
+ all_steps = []
45
+ for i, section_count in enumerate(section_counts):
46
+ size = size_per + (1 if i < extra else 0)
47
+ if size < section_count:
48
+ raise ValueError(
49
+ f"cannot divide section of {size} steps into {section_count}"
50
+ )
51
+ if section_count <= 1:
52
+ frac_stride = 1
53
+ else:
54
+ frac_stride = (size - 1) / (section_count - 1)
55
+ cur_idx = 0.0
56
+ taken_steps = []
57
+ for _ in range(section_count):
58
+ taken_steps.append(start_idx + round(cur_idx))
59
+ cur_idx += frac_stride
60
+ all_steps += taken_steps
61
+ start_idx += size
62
+ return set(all_steps)
63
+
64
+
65
+ class SpacedDiffusion(GaussianDiffusion):
66
+ """
67
+ A diffusion process which can skip steps in a base diffusion process.
68
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
69
+ original diffusion process to retain.
70
+ :param kwargs: the kwargs to create the base diffusion process.
71
+ """
72
+
73
+ def __init__(self, use_timesteps, **kwargs):
74
+ self.use_timesteps = set(use_timesteps)
75
+ self.timestep_map = []
76
+ self.original_num_steps = len(kwargs["betas"])
77
+
78
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79
+ last_alpha_cumprod = 1.0
80
+ new_betas = []
81
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82
+ if i in self.use_timesteps:
83
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84
+ last_alpha_cumprod = alpha_cumprod
85
+ self.timestep_map.append(i)
86
+ kwargs["betas"] = np.array(new_betas)
87
+ super().__init__(**kwargs)
88
+
89
+ def p_mean_variance(
90
+ self, model, *args, **kwargs
91
+ ): # pylint: disable=signature-differs
92
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93
+
94
+ def training_losses(
95
+ self, model, *args, **kwargs
96
+ ): # pylint: disable=signature-differs
97
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
98
+
99
+ def condition_mean(self, cond_fn, *args, **kwargs):
100
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101
+
102
+ def condition_score(self, cond_fn, *args, **kwargs):
103
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104
+
105
+ def _wrap_model(self, model):
106
+ if isinstance(model, _WrappedModel):
107
+ return model
108
+ return _WrappedModel(
109
+ model, self.timestep_map, self.original_num_steps
110
+ )
111
+
112
+ def _scale_timesteps(self, t):
113
+ # Scaling is done by the wrapped model.
114
+ return t
115
+
116
+
117
+ class _WrappedModel:
118
+ def __init__(self, model, timestep_map, original_num_steps):
119
+ self.model = model
120
+ self.timestep_map = timestep_map
121
+ # self.rescale_timesteps = rescale_timesteps
122
+ self.original_num_steps = original_num_steps
123
+
124
+ def __call__(self, x, ts, **kwargs):
125
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126
+ new_ts = map_tensor[ts]
127
+ # if self.rescale_timesteps:
128
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129
+ return self.model(x, new_ts, **kwargs)
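
For reference, a short usage sketch of space_timesteps defined above; the step counts are arbitrary examples, and the import path assumes this repo's package layout.

from src.models.mar.diffusion.respace import space_timesteps

# DDIM-style striding: keep 50 of 1000 original steps with a fixed integer stride
steps = space_timesteps(1000, "ddim50")
print(len(steps))              # 50
print(sorted(steps)[:5])       # [0, 20, 40, 60, 80]

# Two equally-sized sections of the 1000 steps, retaining 10 and 20 steps respectively
steps = space_timesteps(1000, "10,20")
print(len(steps))              # 30
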
src/models/mar/engine_mar.py ADDED
@@ -0,0 +1,99 @@
1
+ import torch
2
+ import src.models.mar.misc as misc
3
+ import torch_fidelity
4
+ import shutil
5
+ import cv2
6
+ import numpy as np
7
+ import os
8
+ import time
9
+
10
+
11
+ def torch_evaluate(model, args):
12
+ model.eval()
13
+ num_steps = args.num_images // (args.batch_size * misc.get_world_size()) + 1
14
+ save_folder = os.path.join(args.output_dir, "ariter{}-temp{}-{}cfg{}-image{}".format(
15
+ args.num_iter, args.temperature, args.cfg_schedule, args.cfg, args.num_images))
16
+
17
+ print("Save to:", save_folder)
18
+ if misc.get_rank() == 0:
19
+ if not os.path.exists(save_folder):
20
+ os.makedirs(save_folder)
21
+
22
+ class_num = args.class_num
23
+ assert args.num_images % class_num == 0 # number of images per class must be the same
24
+ class_label_gen_world = np.arange(0, class_num).repeat(args.num_images // class_num)
25
+ class_label_gen_world = np.hstack([class_label_gen_world, np.zeros(50000)])
26
+ world_size = misc.get_world_size()
27
+ local_rank = misc.get_rank()
28
+ used_time = 0
29
+ gen_img_cnt = 0
30
+
31
+ for i in range(num_steps):
32
+ print("Generation step {}/{}".format(i, num_steps))
33
+
34
+ labels_gen = class_label_gen_world[world_size * args.batch_size * i + local_rank * args.batch_size:
35
+ world_size * args.batch_size * i + (local_rank + 1) * args.batch_size]
36
+ labels_gen = torch.Tensor(labels_gen).long().cuda()
37
+
38
+ torch.cuda.synchronize()
39
+ start_time = time.time()
40
+
41
+ # generation
42
+ with torch.no_grad():
43
+ with torch.cuda.amp.autocast():
44
+ # sampled_images = model.sample_official(bsz=args.batch_size, num_iter=args.num_iter, cfg=args.cfg,
45
+ # cfg_schedule=args.cfg_schedule, labels=labels_gen,
46
+ # temperature=args.temperature)
47
+
48
+ # import pdb; pdb.set_trace()
49
+ if args.cfg != 1.0:
50
+ labels_gen = torch.cat([
51
+ labels_gen, torch.full_like(labels_gen, fill_value=-1)])
52
+ sampled_images = model.sample(labels_gen,
53
+ num_iter=args.num_iter, cfg=args.cfg, cfg_schedule=args.cfg_schedule,
54
+ temperature=args.temperature, progress=False)
55
+
56
+ # measure speed after the first generation batch
57
+ if i >= 1:
58
+ torch.cuda.synchronize()
59
+ used_time += time.time() - start_time
60
+ gen_img_cnt += args.batch_size
61
+ print("Generating {} images takes {:.5f} seconds, {:.5f} sec per image".format(gen_img_cnt, used_time, used_time / gen_img_cnt))
62
+
63
+ torch.distributed.barrier()
64
+ sampled_images = sampled_images.detach().cpu()
65
+ sampled_images = (sampled_images + 1) / 2
66
+
67
+ # distributed save
68
+ for b_id in range(sampled_images.size(0)):
69
+ img_id = i * sampled_images.size(0) * world_size + local_rank * sampled_images.size(0) + b_id
70
+ if img_id >= args.num_images:
71
+ break
72
+ gen_img = np.round(np.clip(sampled_images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255))
73
+ gen_img = gen_img.astype(np.uint8)[:, :, ::-1]
74
+ cv2.imwrite(os.path.join(save_folder, '{}.png'.format(str(img_id).zfill(5))), gen_img)
75
+
76
+ torch.distributed.barrier()
77
+ time.sleep(10)
78
+ if misc.get_rank() == 0:
79
+ input2 = None
80
+ fid_statistics_file = 'fid_stats/adm_in256_stats.npz'
81
+ metrics_dict = torch_fidelity.calculate_metrics(
82
+ input1=save_folder,
83
+ input2=input2,
84
+ fid_statistics_file=fid_statistics_file,
85
+ cuda=True,
86
+ isc=True,
87
+ fid=True,
88
+ kid=False,
89
+ prc=False,
90
+ verbose=True,
91
+ )
92
+ fid = metrics_dict['frechet_inception_distance']
93
+ inception_score = metrics_dict['inception_score_mean']
94
+ print("FID: {:.4f}, Inception Score: {:.4f}".format(fid, inception_score))
95
+ # remove temporal saving folder
96
+ shutil.rmtree(save_folder)
97
+
98
+ torch.distributed.barrier()
99
+ time.sleep(10)
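
For reference, a toy sketch of the per-rank label slicing used in torch_evaluate above; world_size, batch_size, num_images, and class_num here are illustrative assumptions, and the small zero tail mirrors the np.zeros(50000) padding that keeps the final, partially-filled step in range.

import numpy as np

world_size, batch_size, num_images, class_num = 2, 4, 16, 8
labels_world = np.arange(class_num).repeat(num_images // class_num)   # two images per class
labels_world = np.hstack([labels_world, np.zeros(50, dtype=int)])     # tail padding for the last step

num_steps = num_images // (batch_size * world_size) + 1
for step in range(num_steps):
    for rank in range(world_size):
        lo = world_size * batch_size * step + rank * batch_size
        hi = world_size * batch_size * step + (rank + 1) * batch_size
        # each rank sees a disjoint slice per step; the last step reads the padded tail,
        # which is why the saving loop also checks img_id against num_images
        print(step, rank, labels_world[lo:hi])
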
src/models/mar/mar.py ADDED
@@ -0,0 +1,477 @@
1
+ from functools import partial
2
+
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ import scipy.stats as stats
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange
11
+ from torch.utils.checkpoint import checkpoint
12
+ from timm.models.vision_transformer import Block
13
+
14
+ from .diffloss import DiffLoss
15
+
16
+
17
+ def mask_by_order(mask_len, order, bsz, seq_len):
18
+ masking = torch.zeros(bsz, seq_len).to(order.device)
19
+ masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()],
20
+ src=torch.ones(bsz, seq_len).to(order.device)).bool()
21
+ return masking
22
+
23
+
24
+ class MAR(nn.Module):
25
+ """ Masked Autoencoder with VisionTransformer backbone
26
+ """
27
+ def __init__(self, img_size=256, vae_stride=16, patch_size=1,
28
+ encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
29
+ decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
30
+ mlp_ratio=4., norm_layer=nn.LayerNorm,
31
+ vae_embed_dim=16,
32
+ mask_ratio_min=0.7,
33
+ label_drop_prob=0.1,
34
+ class_num=1000,
35
+ attn_dropout=0.1,
36
+ proj_dropout=0.1,
37
+ buffer_size=64,
38
+ diffloss_d=3,
39
+ diffloss_w=1024,
40
+ num_sampling_steps='100',
41
+ diffusion_batch_mul=4,
42
+ grad_checkpointing=False,
43
+ ):
44
+ super().__init__()
45
+
46
+ # --------------------------------------------------------------------------
47
+ # VAE and patchify specifics
48
+ self.vae_embed_dim = vae_embed_dim
49
+
50
+ self.img_size = img_size
51
+ self.vae_stride = vae_stride
52
+ self.patch_size = patch_size
53
+ self.seq_h = self.seq_w = img_size // vae_stride // patch_size
54
+ self.seq_len = self.seq_h * self.seq_w
55
+ self.token_embed_dim = vae_embed_dim * patch_size**2
56
+ self.grad_checkpointing = grad_checkpointing
57
+
58
+ # --------------------------------------------------------------------------
59
+ # Class Embedding
60
+ self.num_classes = class_num
61
+ self.class_emb = nn.Embedding(class_num, encoder_embed_dim)
62
+ self.label_drop_prob = label_drop_prob
63
+ # Fake class embedding for CFG's unconditional generation
64
+ self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))
65
+
66
+ # --------------------------------------------------------------------------
67
+ # MAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
68
+ self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)
69
+
70
+ # --------------------------------------------------------------------------
71
+ # MAR encoder specifics
72
+ self.encoder_embed_dim = encoder_embed_dim
73
+ self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
74
+ self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
75
+ self.buffer_size = buffer_size
76
+ self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim))
77
+
78
+ self.encoder_blocks = nn.ModuleList([
79
+ Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
80
+ proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
81
+ self.encoder_norm = norm_layer(encoder_embed_dim)
82
+
83
+ # --------------------------------------------------------------------------
84
+ # MAR decoder specifics
85
+ self.decoder_embed_dim = decoder_embed_dim
86
+ self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
87
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
88
+ self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))
89
+
90
+ self.decoder_blocks = nn.ModuleList([
91
+ Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
92
+ proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])
93
+
94
+ self.decoder_norm = norm_layer(decoder_embed_dim)
95
+ self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))
96
+
97
+ self.initialize_weights()
98
+
99
+ # --------------------------------------------------------------------------
100
+ # Diffusion Loss
101
+ self.diffloss = DiffLoss(
102
+ target_channels=self.token_embed_dim,
103
+ z_channels=decoder_embed_dim,
104
+ width=diffloss_w,
105
+ depth=diffloss_d,
106
+ num_sampling_steps=num_sampling_steps,
107
+ grad_checkpointing=self.grad_checkpointing
108
+ )
109
+ self.diffusion_batch_mul = diffusion_batch_mul
110
+
111
+ def get_encoder_pos_embed(self, h, w):
112
+ if h == self.seq_h and w == self.seq_w:
113
+ return self.encoder_pos_embed_learned
114
+ buffer_pe, image_pe = self.encoder_pos_embed_learned.split(
115
+ [self.buffer_size, self.seq_len], dim=1)
116
+ image_pe = rearrange(image_pe, 'b (h w) c -> b c h w',
117
+ h=self.seq_h, w=self.seq_w)
118
+ image_pe = F.interpolate(image_pe, size=(h, w), mode='bilinear')
119
+ image_pe = rearrange(image_pe, 'b c h w -> b (h w) c')
120
+
121
+ return torch.cat([buffer_pe, image_pe], dim=1)
122
+
123
+ def get_decoder_pos_embed(self, h, w):
124
+ if h == self.seq_h and w == self.seq_w:
125
+ return self.decoder_pos_embed_learned
126
+ buffer_pe, image_pe = self.decoder_pos_embed_learned.split(
127
+ [self.buffer_size, self.seq_len], dim=1)
128
+ image_pe = rearrange(image_pe, 'b (h w) c -> b c h w',
129
+ h=self.seq_h, w=self.seq_w)
130
+ image_pe = F.interpolate(image_pe, size=(h, w), mode='bilinear')
131
+ image_pe = rearrange(image_pe, 'b c h w -> b (h w) c')
132
+
133
+ return torch.cat([buffer_pe, image_pe], dim=1)
134
+
135
+ def get_diffusion_pos_embed(self, h, w):
136
+ if h == self.seq_h and w == self.seq_w:
137
+ return self.diffusion_pos_embed_learned
138
+ image_pe = self.diffusion_pos_embed_learned
139
+ image_pe = rearrange(image_pe, 'b (h w) c -> b c h w',
140
+ h=self.seq_h, w=self.seq_w)
141
+ image_pe = F.interpolate(image_pe, size=(h, w), mode='bilinear')
142
+ image_pe = rearrange(image_pe, 'b c h w -> b (h w) c')
143
+
144
+ return image_pe
145
+
146
+ def initialize_weights(self):
147
+ # parameters
148
+ torch.nn.init.normal_(self.class_emb.weight, std=.02)
149
+ torch.nn.init.normal_(self.fake_latent, std=.02)
150
+ torch.nn.init.normal_(self.mask_token, std=.02)
151
+ torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
152
+ torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
153
+ torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)
154
+
155
+ # initialize nn.Linear and nn.LayerNorm
156
+ self.apply(self._init_weights)
157
+
158
+ def _init_weights(self, m):
159
+ if isinstance(m, nn.Linear):
160
+ # we use xavier_uniform following official JAX ViT:
161
+ torch.nn.init.xavier_uniform_(m.weight)
162
+ if isinstance(m, nn.Linear) and m.bias is not None:
163
+ nn.init.constant_(m.bias, 0)
164
+ elif isinstance(m, nn.LayerNorm):
165
+ if m.bias is not None:
166
+ nn.init.constant_(m.bias, 0)
167
+ if m.weight is not None:
168
+ nn.init.constant_(m.weight, 1.0)
169
+
170
+ @property
171
+ def device(self):
172
+ return self.fake_latent.data.device
173
+
174
+ @property
175
+ def dtype(self):
176
+ return self.fake_latent.data.dtype
177
+
178
+ def patchify(self, x):
179
+ bsz, c, h, w = x.shape
180
+ p = self.patch_size
181
+ h_, w_ = h // p, w // p
182
+
183
+ x = x.reshape(bsz, c, h_, p, w_, p)
184
+ x = torch.einsum('nchpwq->nhwcpq', x)
185
+ x = x.reshape(bsz, h_ * w_, c * p ** 2)
186
+ return x # [n, l, d]
187
+
188
+ def unpatchify(self, x):
189
+ bsz = x.shape[0]
190
+ p = self.patch_size
191
+ c = self.vae_embed_dim
192
+ h_, w_ = self.seq_h, self.seq_w
193
+
194
+ x = x.reshape(bsz, h_, w_, c, p, p)
195
+ x = torch.einsum('nhwcpq->nchpwq', x)
196
+ x = x.reshape(bsz, c, h_ * p, w_ * p)
197
+ return x # [n, c, h, w]
198
+
199
+ def sample_orders(self, bsz, seq_len=None):
200
+ if seq_len is None:
201
+ seq_len = self.seq_len
202
+ # generate a batch of random generation orders
203
+ orders = []
204
+ for _ in range(bsz):
205
+ order = np.array(list(range(seq_len)))
206
+ np.random.shuffle(order)
207
+ orders.append(order)
208
+ orders = torch.Tensor(np.array(orders)).to(self.device).long()
209
+ return orders
210
+
211
+ def random_masking(self, x, orders):
212
+ # generate token mask
213
+ bsz, seq_len, embed_dim = x.shape
214
+ assert seq_len == orders.shape[1]
215
+ mask_rate = self.mask_ratio_generator.rvs(1)[0]
216
+ num_masked_tokens = int(np.ceil(seq_len * mask_rate))
217
+ mask = torch.zeros(bsz, seq_len, device=x.device)
218
+ mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
219
+ src=torch.ones(bsz, seq_len, device=x.device))
220
+ return mask
221
+
222
+ def forward_mae_encoder(self, x, mask, class_embedding, image_shape=None):
223
+ x = x.to(self.dtype)
224
+ x = self.z_proj(x)
225
+ bsz, seq_len, embed_dim = x.shape
226
+
227
+ # concat buffer
228
+ x = torch.cat([x.new_zeros(bsz, self.buffer_size, embed_dim), x], dim=1)
229
+ mask_with_buffer = torch.cat([mask.new_zeros(x.size(0), self.buffer_size), mask], dim=1)
230
+
231
+ # random drop class embedding during training
232
+ # if self.training:
233
+ # drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
234
+ # drop_latent_mask = drop_latent_mask.unsqueeze(-1).to(self.device).to(x.dtype)
235
+ # class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
236
+
237
+ x[:, :self.buffer_size] = class_embedding.view(bsz, -1, embed_dim)
238
+
239
+ # encoder position embedding
240
+ # x = x + self.encoder_pos_embed_learned
241
+ if image_shape is None:
242
+ x = x + self.encoder_pos_embed_learned
243
+ else:
244
+ h, w = image_shape
245
+ assert h * w == seq_len
246
+ x = x + self.get_encoder_pos_embed(h=h, w=w)
247
+ # import pdb; pdb.set_trace()
248
+ x = self.z_proj_ln(x)
249
+
250
+ # dropping
251
+ x = x[(1-mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)
252
+
253
+ # apply Transformer blocks
254
+ if self.grad_checkpointing and not torch.jit.is_scripting():
255
+ for block in self.encoder_blocks:
256
+ x = checkpoint(block, x,
257
+ use_reentrant=False
258
+ )
259
+ else:
260
+ for block in self.encoder_blocks:
261
+ x = block(x)
262
+ x = self.encoder_norm(x)
263
+
264
+ return x
265
+
266
+ def forward_mae_decoder(self, x, mask, image_shape=None, x_con=None):
267
+ bsz, seq_len = mask.shape
268
+
269
+ x = self.decoder_embed(x)
270
+ mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
271
+
272
+ # pad mask tokens
273
+ mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
274
+
275
+ if x_con is not None:
276
+ x_after_pad = self.decoder_embed(x_con)
277
+ else:
278
+ x_after_pad = mask_tokens.clone()
279
+ x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
280
+
281
+ # decoder position embedding
282
+ # x = x_after_pad + self.decoder_pos_embed_learned
283
+ if image_shape is None:
284
+ x = x_after_pad + self.decoder_pos_embed_learned
285
+ else:
286
+ h, w = image_shape
287
+ assert h * w == seq_len
288
+ x = x_after_pad + self.get_decoder_pos_embed(h=h, w=w)
289
+
290
+ # apply Transformer blocks
291
+ if self.grad_checkpointing and not torch.jit.is_scripting():
292
+ for block in self.decoder_blocks:
293
+ x = checkpoint(block, x,
294
+ # use_reentrant=False
295
+ )
296
+ else:
297
+ for block in self.decoder_blocks:
298
+ x = block(x)
299
+ x = self.decoder_norm(x)
300
+
301
+ x = x[:, self.buffer_size:]
302
+ # x = x + self.diffusion_pos_embed_learned
303
+ if image_shape is None:
304
+ x = x + self.diffusion_pos_embed_learned
305
+ else:
306
+ h, w = image_shape
307
+ assert h * w == seq_len
308
+ x = x + self.get_diffusion_pos_embed(h=h, w=w)
309
+ return x
310
+
311
+ def mae_decoder_prepare(self, x, mask):
312
+ x = self.decoder_embed(x)
313
+ mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
314
+
315
+ # pad mask tokens
316
+ mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
317
+ x_after_pad = mask_tokens.clone()
318
+ x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
319
+
320
+ # decoder position embedding
321
+ x = x_after_pad + self.decoder_pos_embed_learned
322
+
323
+ return x
324
+
325
+ def mae_decoder_forward(self, x):
326
+ # apply Transformer blocks
327
+ if self.grad_checkpointing and not torch.jit.is_scripting():
328
+ for block in self.decoder_blocks:
329
+ x = checkpoint(block, x,
330
+ # use_reentrant=False
331
+ )
332
+ else:
333
+ for block in self.decoder_blocks:
334
+ x = block(x)
335
+ x = self.decoder_norm(x)
336
+
337
+ x = x[:, self.buffer_size:]
338
+ x = x + self.diffusion_pos_embed_learned
339
+ return x
340
+
341
+ def forward_loss(self, z, target, mask):
342
+ bsz, seq_len, _ = target.shape
343
+ target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
344
+ z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1)
345
+ mask = mask.reshape(bsz*seq_len).repeat(self.diffusion_batch_mul)
346
+ loss = self.diffloss(z=z, target=target, mask=mask)
347
+ return loss
348
+
349
+ def forward(self, imgs, labels):
350
+
351
+ # class embed
352
+ class_embedding = self.class_emb(labels)
353
+
354
+ # patchify and mask (drop) tokens
355
+ x = self.patchify(imgs)
356
+ gt_latents = x.clone().detach()
357
+ orders = self.sample_orders(bsz=x.size(0))
358
+ mask = self.random_masking(x, orders)
359
+
360
+ # mae encoder
361
+ x = self.forward_mae_encoder(x, mask, class_embedding)
362
+
363
+ # mae decoder
364
+ z = self.forward_mae_decoder(x, mask)
365
+
366
+ # diffloss
367
+ loss = self.forward_loss(z=z, target=gt_latents, mask=mask)
368
+
369
+ return loss
370
+
371
+ def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
372
+ # import pdb; pdb.set_trace()
373
+ # init and sample generation orders
374
+ mask = torch.ones(bsz, self.seq_len).to(self.device)
375
+ tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).to(self.device)
376
+ orders = self.sample_orders(bsz)
377
+
378
+ indices = list(range(num_iter))
379
+ if progress:
380
+ indices = tqdm(indices)
381
+ # generate latents
382
+ for step in indices:
383
+ cur_tokens = tokens.clone()
384
+
385
+ # class embedding and CFG
386
+ if labels is not None:
387
+ class_embedding = self.class_emb(labels)
388
+ else:
389
+ class_embedding = self.fake_latent.repeat(bsz, 1)
390
+ if not cfg == 1.0:
391
+ tokens = torch.cat([tokens, tokens], dim=0)
392
+ class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
393
+ mask = torch.cat([mask, mask], dim=0)
394
+
395
+ # mae encoder
396
+ x = self.forward_mae_encoder(tokens, mask.to(self.dtype), class_embedding)
397
+
398
+ # mae decoder
399
+ z = self.forward_mae_decoder(x, mask.to(self.dtype))
400
+ # import pdb; pdb.set_trace()
401
+
402
+ # mask ratio for the next round, following MaskGIT and MAGE.
403
+ mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
404
+ mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).to(self.device)
405
+ # import pdb; pdb.set_trace()
406
+ # masks out at least one for the next iteration
407
+ mask_len = torch.maximum(torch.Tensor([1]).to(self.device),
408
+ torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
409
+ # import pdb; pdb.set_trace()
410
+ # get masking for next iteration and locations to be predicted in this iteration
411
+ mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
412
+ # import pdb; pdb.set_trace()
413
+ if step >= num_iter - 1:
414
+ mask_to_pred = mask[:bsz].bool()
415
+ else:
416
+ mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
417
+ mask = mask_next
418
+ if not cfg == 1.0:
419
+ mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
420
+ # import pdb; pdb.set_trace()
421
+ # sample token latents for this step
422
+ z = z[mask_to_pred.nonzero(as_tuple=True)]
423
+ # cfg schedule follow Muse
424
+ if cfg_schedule == "linear":
425
+ cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
426
+ elif cfg_schedule == "constant":
427
+ cfg_iter = cfg
428
+ else:
429
+ raise NotImplementedError
430
+ sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)
431
+ if not cfg == 1.0:
432
+ sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0) # Remove null class samples
433
+ mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
434
+ # import pdb; pdb.set_trace()
435
+ cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
436
+ tokens = cur_tokens.clone()
437
+
438
+ # unpatchify
439
+ tokens = self.unpatchify(tokens)
440
+ return tokens
441
+
442
+ def gradient_checkpointing_enable(self):
443
+ self.grad_checkpointing = True
444
+
445
+ def gradient_checkpointing_disable(self):
446
+ self.grad_checkpointing = False
447
+
448
+
449
+ def mar_base(**kwargs):
450
+ model = MAR(
451
+ encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
452
+ decoder_embed_dim=768, decoder_depth=12, decoder_num_heads=12,
453
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
454
+ return model
455
+
456
+
457
+ def mar_large(**kwargs):
458
+ model = MAR(
459
+ encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
460
+ decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
461
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
462
+ return model
463
+
464
+
465
+ def mar_huge(**kwargs):
466
+ model = MAR(
467
+ encoder_embed_dim=1280, encoder_depth=20, encoder_num_heads=16,
468
+ decoder_embed_dim=1280, decoder_depth=20, decoder_num_heads=16,
469
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
470
+ return model
471
+
472
+ def mar_max(**kwargs):
473
+ model = MAR(
474
+ encoder_embed_dim=1536, encoder_depth=24, encoder_num_heads=16,
475
+ decoder_embed_dim=1536, decoder_depth=24, decoder_num_heads=16,
476
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
477
+ return model
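
For reference, a small sketch of the MaskGIT-style cosine schedule that sample_tokens above uses to decide how many tokens remain masked after each iteration; seq_len and num_iter are illustrative assumptions.

import math
import numpy as np

seq_len, num_iter = 256, 8
remaining = seq_len
for step in range(num_iter):
    mask_ratio = np.cos(math.pi / 2.0 * (step + 1) / num_iter)
    mask_len = int(np.floor(seq_len * mask_ratio))
    mask_len = max(1, min(remaining - 1, mask_len))      # always leave at least one token to predict
    # tokens predicted this step; the final step predicts everything still masked
    predicted = remaining if step == num_iter - 1 else remaining - mask_len
    print(step, mask_len, predicted)
    remaining = mask_len
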
src/models/mar/misc.py ADDED
@@ -0,0 +1,340 @@
1
+ import builtins
2
+ import datetime
3
+ import os
4
+ import time
5
+ from collections import defaultdict, deque
6
+ from pathlib import Path
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ TORCH_MAJOR = int(torch.__version__.split('.')[0])
11
+ TORCH_MINOR = int(torch.__version__.split('.')[1])
12
+
13
+ if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
14
+ from torch._six import inf
15
+ else:
16
+ from torch import inf
17
+ import copy
18
+
19
+
20
+ class SmoothedValue(object):
21
+ """Track a series of values and provide access to smoothed values over a
22
+ window or the global series average.
23
+ """
24
+
25
+ def __init__(self, window_size=20, fmt=None):
26
+ if fmt is None:
27
+ fmt = "{median:.4f} ({global_avg:.4f})"
28
+ self.deque = deque(maxlen=window_size)
29
+ self.total = 0.0
30
+ self.count = 0
31
+ self.fmt = fmt
32
+
33
+ def update(self, value, n=1):
34
+ self.deque.append(value)
35
+ self.count += n
36
+ self.total += value * n
37
+
38
+ def synchronize_between_processes(self):
39
+ """
40
+ Warning: does not synchronize the deque!
41
+ """
42
+ if not is_dist_avail_and_initialized():
43
+ return
44
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
45
+ dist.barrier()
46
+ dist.all_reduce(t)
47
+ t = t.tolist()
48
+ self.count = int(t[0])
49
+ self.total = t[1]
50
+
51
+ @property
52
+ def median(self):
53
+ d = torch.tensor(list(self.deque))
54
+ return d.median().item()
55
+
56
+ @property
57
+ def avg(self):
58
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
59
+ return d.mean().item()
60
+
61
+ @property
62
+ def global_avg(self):
63
+ return self.total / self.count
64
+
65
+ @property
66
+ def max(self):
67
+ return max(self.deque)
68
+
69
+ @property
70
+ def value(self):
71
+ return self.deque[-1]
72
+
73
+ def __str__(self):
74
+ return self.fmt.format(
75
+ median=self.median,
76
+ avg=self.avg,
77
+ global_avg=self.global_avg,
78
+ max=self.max,
79
+ value=self.value)
80
+
81
+
82
+ class MetricLogger(object):
83
+ def __init__(self, delimiter="\t"):
84
+ self.meters = defaultdict(SmoothedValue)
85
+ self.delimiter = delimiter
86
+
87
+ def update(self, **kwargs):
88
+ for k, v in kwargs.items():
89
+ if v is None:
90
+ continue
91
+ if isinstance(v, torch.Tensor):
92
+ v = v.item()
93
+ assert isinstance(v, (float, int))
94
+ self.meters[k].update(v)
95
+
96
+ def __getattr__(self, attr):
97
+ if attr in self.meters:
98
+ return self.meters[attr]
99
+ if attr in self.__dict__:
100
+ return self.__dict__[attr]
101
+ raise AttributeError("'{}' object has no attribute '{}'".format(
102
+ type(self).__name__, attr))
103
+
104
+ def __str__(self):
105
+ loss_str = []
106
+ for name, meter in self.meters.items():
107
+ loss_str.append(
108
+ "{}: {}".format(name, str(meter))
109
+ )
110
+ return self.delimiter.join(loss_str)
111
+
112
+ def synchronize_between_processes(self):
113
+ for meter in self.meters.values():
114
+ meter.synchronize_between_processes()
115
+
116
+ def add_meter(self, name, meter):
117
+ self.meters[name] = meter
118
+
119
+ def log_every(self, iterable, print_freq, header=None):
120
+ i = 0
121
+ if not header:
122
+ header = ''
123
+ start_time = time.time()
124
+ end = time.time()
125
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
126
+ data_time = SmoothedValue(fmt='{avg:.4f}')
127
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
128
+ log_msg = [
129
+ header,
130
+ '[{0' + space_fmt + '}/{1}]',
131
+ 'eta: {eta}',
132
+ '{meters}',
133
+ 'time: {time}',
134
+ 'data: {data}'
135
+ ]
136
+ if torch.cuda.is_available():
137
+ log_msg.append('max mem: {memory:.0f}')
138
+ log_msg = self.delimiter.join(log_msg)
139
+ MB = 1024.0 * 1024.0
140
+ for obj in iterable:
141
+ data_time.update(time.time() - end)
142
+ yield obj
143
+ iter_time.update(time.time() - end)
144
+ if i % print_freq == 0 or i == len(iterable) - 1:
145
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
146
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
147
+ if torch.cuda.is_available():
148
+ print(log_msg.format(
149
+ i, len(iterable), eta=eta_string,
150
+ meters=str(self),
151
+ time=str(iter_time), data=str(data_time),
152
+ memory=torch.cuda.max_memory_allocated() / MB))
153
+ else:
154
+ print(log_msg.format(
155
+ i, len(iterable), eta=eta_string,
156
+ meters=str(self),
157
+ time=str(iter_time), data=str(data_time)))
158
+ i += 1
159
+ end = time.time()
160
+ total_time = time.time() - start_time
161
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
162
+ print('{} Total time: {} ({:.4f} s / it)'.format(
163
+ header, total_time_str, total_time / len(iterable)))
164
+
165
+
166
+ def setup_for_distributed(is_master):
167
+ """
168
+ This function disables printing when not in master process
169
+ """
170
+ builtin_print = builtins.print
171
+
172
+ def print(*args, **kwargs):
173
+ force = kwargs.pop('force', False)
174
+ force = force or (get_world_size() > 8)
175
+ if is_master or force:
176
+ now = datetime.datetime.now().time()
177
+ builtin_print('[{}] '.format(now), end='') # print with time stamp
178
+ builtin_print(*args, **kwargs)
179
+
180
+ builtins.print = print
181
+
182
+
183
+ def is_dist_avail_and_initialized():
184
+ if not dist.is_available():
185
+ return False
186
+ if not dist.is_initialized():
187
+ return False
188
+ return True
189
+
190
+
191
+ def get_world_size():
192
+ if not is_dist_avail_and_initialized():
193
+ return 1
194
+ return dist.get_world_size()
195
+
196
+
197
+ def get_rank():
198
+ if not is_dist_avail_and_initialized():
199
+ return 0
200
+ return dist.get_rank()
201
+
202
+
203
+ def is_main_process():
204
+ return get_rank() == 0
205
+
206
+
207
+ def save_on_master(*args, **kwargs):
208
+ if is_main_process():
209
+ torch.save(*args, **kwargs)
210
+
211
+
212
+ def init_distributed_mode(args):
213
+ if args.dist_on_itp:
214
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
215
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
216
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
217
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
218
+ os.environ['LOCAL_RANK'] = str(args.gpu)
219
+ os.environ['RANK'] = str(args.rank)
220
+ os.environ['WORLD_SIZE'] = str(args.world_size)
221
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
222
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
223
+ args.rank = int(os.environ["RANK"])
224
+ args.world_size = int(os.environ['WORLD_SIZE'])
225
+ args.gpu = int(os.environ['LOCAL_RANK'])
226
+ elif 'SLURM_PROCID' in os.environ:
227
+ args.rank = int(os.environ['SLURM_PROCID'])
228
+ args.gpu = args.rank % torch.cuda.device_count()
229
+ else:
230
+ print('Not using distributed mode')
231
+ setup_for_distributed(is_master=True) # hack
232
+ args.distributed = False
233
+ return
234
+
235
+ args.distributed = True
236
+
237
+ torch.cuda.set_device(args.gpu)
238
+ args.dist_backend = 'nccl'
239
+ print('| distributed init (rank {}): {}, gpu {}'.format(
240
+ args.rank, args.dist_url, args.gpu), flush=True)
241
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
242
+ world_size=args.world_size, rank=args.rank)
243
+ torch.distributed.barrier()
244
+ setup_for_distributed(args.rank == 0)
245
+
246
+
247
+ class NativeScalerWithGradNormCount:
248
+ state_dict_key = "amp_scaler"
249
+
250
+ def __init__(self):
251
+ self._scaler = torch.cuda.amp.GradScaler()
252
+
253
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
254
+ self._scaler.scale(loss).backward(create_graph=create_graph)
255
+ if update_grad:
256
+ if clip_grad is not None:
257
+ assert parameters is not None
258
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
259
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
260
+ else:
261
+ self._scaler.unscale_(optimizer)
262
+ norm = get_grad_norm_(parameters)
263
+ self._scaler.step(optimizer)
264
+ self._scaler.update()
265
+ else:
266
+ norm = None
267
+ return norm
268
+
269
+ def state_dict(self):
270
+ return self._scaler.state_dict()
271
+
272
+ def load_state_dict(self, state_dict):
273
+ self._scaler.load_state_dict(state_dict)
274
+
275
+
276
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
277
+ if isinstance(parameters, torch.Tensor):
278
+ parameters = [parameters]
279
+ parameters = [p for p in parameters if p.grad is not None]
280
+ norm_type = float(norm_type)
281
+ if len(parameters) == 0:
282
+ return torch.tensor(0.)
283
+ device = parameters[0].grad.device
284
+ if norm_type == inf:
285
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
286
+ else:
287
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
288
+ return total_norm
289
+
290
+
291
+ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
292
+ decay = []
293
+ no_decay = []
294
+ for name, param in model.named_parameters():
295
+ if not param.requires_grad:
296
+ continue # frozen weights
297
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or 'diffloss' in name:
298
+ no_decay.append(param) # no weight decay on bias, norm and diffloss
299
+ else:
300
+ decay.append(param)
301
+ return [
302
+ {'params': no_decay, 'weight_decay': 0.},
303
+ {'params': decay, 'weight_decay': weight_decay}]
304
+
305
+
306
+ def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, ema_params=None, epoch_name=None):
307
+ if epoch_name is None:
308
+ epoch_name = str(epoch)
309
+ output_dir = Path(args.output_dir)
310
+ checkpoint_path = output_dir / ('checkpoint-%s.pth' % epoch_name)
311
+
312
+ # ema
313
+ if ema_params is not None:
314
+ ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
315
+ for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
316
+ assert name in ema_state_dict
317
+ ema_state_dict[name] = ema_params[i]
318
+ else:
319
+ ema_state_dict = None
320
+
321
+ to_save = {
322
+ 'model': model_without_ddp.state_dict(),
323
+ 'model_ema': ema_state_dict,
324
+ 'optimizer': optimizer.state_dict(),
325
+ 'epoch': epoch,
326
+ 'scaler': loss_scaler.state_dict(),
327
+ 'args': args,
328
+ }
329
+ save_on_master(to_save, checkpoint_path)
330
+
331
+
332
+ def all_reduce_mean(x):
333
+ world_size = get_world_size()
334
+ if world_size > 1:
335
+ x_reduce = torch.tensor(x).cuda()
336
+ dist.all_reduce(x_reduce)
337
+ x_reduce /= world_size
338
+ return x_reduce.item()
339
+ else:
340
+ return x
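
For reference, a usage sketch of add_weight_decay defined above; the toy model and hyper-parameters are assumptions, and the import path assumes this repo's layout.

import torch
import torch.nn as nn
from src.models.mar.misc import add_weight_decay

model = nn.Sequential(nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 4))
param_groups = add_weight_decay(model, weight_decay=0.05)
optimizer = torch.optim.AdamW(param_groups, lr=1e-4)
# biases and LayerNorm parameters fall into the zero-weight-decay group
print([len(g['params']) for g in param_groups])   # [4, 2]
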
src/models/mar/vae.py ADDED
@@ -0,0 +1,525 @@
1
+ # Adopted from LDM's KL-VAE: https://github.com/CompVis/latent-diffusion
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ import numpy as np
6
+
7
+
8
+ def nonlinearity(x):
9
+ # swish
10
+ return x * torch.sigmoid(x)
11
+
12
+
13
+ def Normalize(in_channels, num_groups=32):
14
+ return torch.nn.GroupNorm(
15
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
16
+ )
17
+
18
+
19
+ class Upsample(nn.Module):
20
+ def __init__(self, in_channels, with_conv):
21
+ super().__init__()
22
+ self.with_conv = with_conv
23
+ if self.with_conv:
24
+ self.conv = torch.nn.Conv2d(
25
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
26
+ )
27
+
28
+ def forward(self, x):
29
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
30
+ if self.with_conv:
31
+ x = self.conv(x)
32
+ return x
33
+
34
+
35
+ class Downsample(nn.Module):
36
+ def __init__(self, in_channels, with_conv):
37
+ super().__init__()
38
+ self.with_conv = with_conv
39
+ if self.with_conv:
40
+ # no asymmetric padding in torch conv, must do it ourselves
41
+ self.conv = torch.nn.Conv2d(
42
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
43
+ )
44
+
45
+ def forward(self, x):
46
+ if self.with_conv:
47
+ pad = (0, 1, 0, 1)
48
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
49
+ x = self.conv(x)
50
+ else:
51
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
52
+ return x
53
+
54
+
55
+ class ResnetBlock(nn.Module):
56
+ def __init__(
57
+ self,
58
+ *,
59
+ in_channels,
60
+ out_channels=None,
61
+ conv_shortcut=False,
62
+ dropout,
63
+ temb_channels=512,
64
+ ):
65
+ super().__init__()
66
+ self.in_channels = in_channels
67
+ out_channels = in_channels if out_channels is None else out_channels
68
+ self.out_channels = out_channels
69
+ self.use_conv_shortcut = conv_shortcut
70
+
71
+ self.norm1 = Normalize(in_channels)
72
+ self.conv1 = torch.nn.Conv2d(
73
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
74
+ )
75
+ if temb_channels > 0:
76
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
77
+ self.norm2 = Normalize(out_channels)
78
+ self.dropout = torch.nn.Dropout(dropout)
79
+ self.conv2 = torch.nn.Conv2d(
80
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
81
+ )
82
+ if self.in_channels != self.out_channels:
83
+ if self.use_conv_shortcut:
84
+ self.conv_shortcut = torch.nn.Conv2d(
85
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
86
+ )
87
+ else:
88
+ self.nin_shortcut = torch.nn.Conv2d(
89
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
90
+ )
91
+
92
+ def forward(self, x, temb):
93
+ h = x
94
+ h = self.norm1(h)
95
+ h = nonlinearity(h)
96
+ h = self.conv1(h)
97
+
98
+ if temb is not None:
99
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
100
+
101
+ h = self.norm2(h)
102
+ h = nonlinearity(h)
103
+ h = self.dropout(h)
104
+ h = self.conv2(h)
105
+
106
+ if self.in_channels != self.out_channels:
107
+ if self.use_conv_shortcut:
108
+ x = self.conv_shortcut(x)
109
+ else:
110
+ x = self.nin_shortcut(x)
111
+
112
+ return x + h
113
+
114
+
115
+ class AttnBlock(nn.Module):
116
+ def __init__(self, in_channels):
117
+ super().__init__()
118
+ self.in_channels = in_channels
119
+
120
+ self.norm = Normalize(in_channels)
121
+ self.q = torch.nn.Conv2d(
122
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
123
+ )
124
+ self.k = torch.nn.Conv2d(
125
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
126
+ )
127
+ self.v = torch.nn.Conv2d(
128
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
129
+ )
130
+ self.proj_out = torch.nn.Conv2d(
131
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
132
+ )
133
+
134
+ def forward(self, x):
135
+ h_ = x
136
+ h_ = self.norm(h_)
137
+ q = self.q(h_)
138
+ k = self.k(h_)
139
+ v = self.v(h_)
140
+
141
+ # compute attention
142
+ b, c, h, w = q.shape
143
+ q = q.reshape(b, c, h * w)
144
+ q = q.permute(0, 2, 1) # b,hw,c
145
+ k = k.reshape(b, c, h * w) # b,c,hw
146
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
147
+ w_ = w_ * (int(c) ** (-0.5))
148
+ w_ = torch.nn.functional.softmax(w_, dim=2)
149
+
150
+ # attend to values
151
+ v = v.reshape(b, c, h * w)
152
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
153
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
154
+ h_ = h_.reshape(b, c, h, w)
155
+
156
+ h_ = self.proj_out(h_)
157
+
158
+ return x + h_
159
+
160
+
161
+ class Encoder(nn.Module):
162
+ def __init__(
163
+ self,
164
+ *,
165
+ ch=128,
166
+ out_ch=3,
167
+ ch_mult=(1, 1, 2, 2, 4),
168
+ num_res_blocks=2,
169
+ attn_resolutions=(16,),
170
+ dropout=0.0,
171
+ resamp_with_conv=True,
172
+ in_channels=3,
173
+ resolution=256,
174
+ z_channels=16,
175
+ double_z=True,
176
+ **ignore_kwargs,
177
+ ):
178
+ super().__init__()
179
+ self.ch = ch
180
+ self.temb_ch = 0
181
+ self.num_resolutions = len(ch_mult)
182
+ self.num_res_blocks = num_res_blocks
183
+ self.resolution = resolution
184
+ self.in_channels = in_channels
185
+
186
+ # downsampling
187
+ self.conv_in = torch.nn.Conv2d(
188
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
189
+ )
190
+
191
+ curr_res = resolution
192
+ in_ch_mult = (1,) + tuple(ch_mult)
193
+ self.down = nn.ModuleList()
194
+ for i_level in range(self.num_resolutions):
195
+ block = nn.ModuleList()
196
+ attn = nn.ModuleList()
197
+ block_in = ch * in_ch_mult[i_level]
198
+ block_out = ch * ch_mult[i_level]
199
+ for i_block in range(self.num_res_blocks):
200
+ block.append(
201
+ ResnetBlock(
202
+ in_channels=block_in,
203
+ out_channels=block_out,
204
+ temb_channels=self.temb_ch,
205
+ dropout=dropout,
206
+ )
207
+ )
208
+ block_in = block_out
209
+ if curr_res in attn_resolutions:
210
+ attn.append(AttnBlock(block_in))
211
+ down = nn.Module()
212
+ down.block = block
213
+ down.attn = attn
214
+ if i_level != self.num_resolutions - 1:
215
+ down.downsample = Downsample(block_in, resamp_with_conv)
216
+ curr_res = curr_res // 2
217
+ self.down.append(down)
218
+
219
+ # middle
220
+ self.mid = nn.Module()
221
+ self.mid.block_1 = ResnetBlock(
222
+ in_channels=block_in,
223
+ out_channels=block_in,
224
+ temb_channels=self.temb_ch,
225
+ dropout=dropout,
226
+ )
227
+ self.mid.attn_1 = AttnBlock(block_in)
228
+ self.mid.block_2 = ResnetBlock(
229
+ in_channels=block_in,
230
+ out_channels=block_in,
231
+ temb_channels=self.temb_ch,
232
+ dropout=dropout,
233
+ )
234
+
235
+ # end
236
+ self.norm_out = Normalize(block_in)
237
+ self.conv_out = torch.nn.Conv2d(
238
+ block_in,
239
+ 2 * z_channels if double_z else z_channels,
240
+ kernel_size=3,
241
+ stride=1,
242
+ padding=1,
243
+ )
244
+
245
+ def forward(self, x):
246
+ # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
247
+
248
+ # timestep embedding
249
+ temb = None
250
+
251
+ # downsampling
252
+ hs = [self.conv_in(x)]
253
+ for i_level in range(self.num_resolutions):
254
+ for i_block in range(self.num_res_blocks):
255
+ h = self.down[i_level].block[i_block](hs[-1], temb)
256
+ if len(self.down[i_level].attn) > 0:
257
+ h = self.down[i_level].attn[i_block](h)
258
+ hs.append(h)
259
+ if i_level != self.num_resolutions - 1:
260
+ hs.append(self.down[i_level].downsample(hs[-1]))
261
+
262
+ # middle
263
+ h = hs[-1]
264
+ h = self.mid.block_1(h, temb)
265
+ h = self.mid.attn_1(h)
266
+ h = self.mid.block_2(h, temb)
267
+
268
+ # end
269
+ h = self.norm_out(h)
270
+ h = nonlinearity(h)
271
+ h = self.conv_out(h)
272
+ return h
273
+
274
+
275
+ class Decoder(nn.Module):
276
+ def __init__(
277
+ self,
278
+ *,
279
+ ch=128,
280
+ out_ch=3,
281
+ ch_mult=(1, 1, 2, 2, 4),
282
+ num_res_blocks=2,
283
+ attn_resolutions=(),
284
+ dropout=0.0,
285
+ resamp_with_conv=True,
286
+ in_channels=3,
287
+ resolution=256,
288
+ z_channels=16,
289
+ give_pre_end=False,
290
+ **ignore_kwargs,
291
+ ):
292
+ super().__init__()
293
+ self.ch = ch
294
+ self.temb_ch = 0
295
+ self.num_resolutions = len(ch_mult)
296
+ self.num_res_blocks = num_res_blocks
297
+ self.resolution = resolution
298
+ self.in_channels = in_channels
299
+ self.give_pre_end = give_pre_end
300
+
301
+ # compute in_ch_mult, block_in and curr_res at lowest res
302
+ in_ch_mult = (1,) + tuple(ch_mult)
303
+ block_in = ch * ch_mult[self.num_resolutions - 1]
304
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
305
+ self.z_shape = (1, z_channels, curr_res, curr_res)
306
+ print(
307
+ "Working with z of shape {} = {} dimensions.".format(
308
+ self.z_shape, np.prod(self.z_shape)
309
+ )
310
+ )
311
+
312
+ # z to block_in
313
+ self.conv_in = torch.nn.Conv2d(
314
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
315
+ )
316
+
317
+ # middle
318
+ self.mid = nn.Module()
319
+ self.mid.block_1 = ResnetBlock(
320
+ in_channels=block_in,
321
+ out_channels=block_in,
322
+ temb_channels=self.temb_ch,
323
+ dropout=dropout,
324
+ )
325
+ self.mid.attn_1 = AttnBlock(block_in)
326
+ self.mid.block_2 = ResnetBlock(
327
+ in_channels=block_in,
328
+ out_channels=block_in,
329
+ temb_channels=self.temb_ch,
330
+ dropout=dropout,
331
+ )
332
+
333
+ # upsampling
334
+ self.up = nn.ModuleList()
335
+ for i_level in reversed(range(self.num_resolutions)):
336
+ block = nn.ModuleList()
337
+ attn = nn.ModuleList()
338
+ block_out = ch * ch_mult[i_level]
339
+ for i_block in range(self.num_res_blocks + 1):
340
+ block.append(
341
+ ResnetBlock(
342
+ in_channels=block_in,
343
+ out_channels=block_out,
344
+ temb_channels=self.temb_ch,
345
+ dropout=dropout,
346
+ )
347
+ )
348
+ block_in = block_out
349
+ if curr_res in attn_resolutions:
350
+ attn.append(AttnBlock(block_in))
351
+ up = nn.Module()
352
+ up.block = block
353
+ up.attn = attn
354
+ if i_level != 0:
355
+ up.upsample = Upsample(block_in, resamp_with_conv)
356
+ curr_res = curr_res * 2
357
+ self.up.insert(0, up) # prepend to get consistent order
358
+
359
+ # end
360
+ self.norm_out = Normalize(block_in)
361
+ self.conv_out = torch.nn.Conv2d(
362
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
363
+ )
364
+
365
+ def forward(self, z):
366
+ # assert z.shape[1:] == self.z_shape[1:]
367
+ self.last_z_shape = z.shape
368
+
369
+ # timestep embedding
370
+ temb = None
371
+
372
+ # z to block_in
373
+ h = self.conv_in(z)
374
+
375
+ # middle
376
+ h = self.mid.block_1(h, temb)
377
+ h = self.mid.attn_1(h)
378
+ h = self.mid.block_2(h, temb)
379
+
380
+ # upsampling
381
+ for i_level in reversed(range(self.num_resolutions)):
382
+ for i_block in range(self.num_res_blocks + 1):
383
+ h = self.up[i_level].block[i_block](h, temb)
384
+ if len(self.up[i_level].attn) > 0:
385
+ h = self.up[i_level].attn[i_block](h)
386
+ if i_level != 0:
387
+ h = self.up[i_level].upsample(h)
388
+
389
+ # end
390
+ if self.give_pre_end:
391
+ return h
392
+
393
+ h = self.norm_out(h)
394
+ h = nonlinearity(h)
395
+ h = self.conv_out(h)
396
+ return h
397
+
398
+
399
+ class DiagonalGaussianDistribution(object):
400
+ def __init__(self, parameters, deterministic=False):
401
+ self.parameters = parameters
402
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
403
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
404
+ self.deterministic = deterministic
405
+ self.std = torch.exp(0.5 * self.logvar)
406
+ self.var = torch.exp(self.logvar)
407
+ if self.deterministic:
408
+ self.var = self.std = torch.zeros_like(self.mean).to(
409
+ device=self.parameters.device
410
+ )
411
+
412
+ def sample(self):
413
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(
414
+ device=self.parameters.device
415
+ )
416
+ return x
417
+
418
+ def kl(self, other=None):
419
+ if self.deterministic:
420
+ return torch.Tensor([0.0])
421
+ else:
422
+ if other is None:
423
+ return 0.5 * torch.sum(
424
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
425
+ dim=[1, 2, 3],
426
+ )
427
+ else:
428
+ return 0.5 * torch.sum(
429
+ torch.pow(self.mean - other.mean, 2) / other.var
430
+ + self.var / other.var
431
+ - 1.0
432
+ - self.logvar
433
+ + other.logvar,
434
+ dim=[1, 2, 3],
435
+ )
436
+
437
+ def nll(self, sample, dims=[1, 2, 3]):
438
+ if self.deterministic:
439
+ return torch.Tensor([0.0])
440
+ logtwopi = np.log(2.0 * np.pi)
441
+ return 0.5 * torch.sum(
442
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
443
+ dim=dims,
444
+ )
445
+
446
+ def mode(self):
447
+ return self.mean
448
+
449
+
450
+ class AutoencoderKL(nn.Module):
451
+ def __init__(self, embed_dim, ch_mult, use_variational=True, ckpt_path=None):
452
+ super().__init__()
453
+ self.encoder = Encoder(ch_mult=ch_mult, z_channels=embed_dim)
454
+ self.decoder = Decoder(ch_mult=ch_mult, z_channels=embed_dim)
455
+ self.use_variational = use_variational
456
+ mult = 2 if self.use_variational else 1
457
+ self.quant_conv = torch.nn.Conv2d(2 * embed_dim, mult * embed_dim, 1)
458
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, embed_dim, 1)
459
+ self.embed_dim = embed_dim
460
+ if ckpt_path is not None:
461
+ self.init_from_ckpt(ckpt_path)
462
+
463
+ def init_from_ckpt(self, path):
464
+ sd = torch.load(path, map_location="cpu")["model"]
465
+ msg = self.load_state_dict(sd, strict=False)
466
+ print("Loading pre-trained KL-VAE")
467
+ print("Missing keys:")
468
+ print(msg.missing_keys)
469
+ print("Unexpected keys:")
470
+ print(msg.unexpected_keys)
471
+ print(f"Restored from {path}")
472
+
473
+ def encode(self, x):
474
+ h = self.encoder(x)
475
+ moments = self.quant_conv(h)
476
+ if not self.use_variational:
477
+ moments = torch.cat((moments, torch.ones_like(moments)), 1)
478
+ posterior = DiagonalGaussianDistribution(moments)
479
+ return posterior
480
+
481
+ def decode(self, z):
482
+ z = self.post_quant_conv(z)
483
+ dec = self.decoder(z)
484
+ return dec
485
+
486
+ def forward(self, inputs, disable=True, train=True, optimizer_idx=0):
487
+ if train:
488
+ return self.training_step(inputs, disable, optimizer_idx)
489
+ else:
490
+ return self.validation_step(inputs, disable)
491
+
492
+
493
+ if __name__ == "__main__":
494
+ from PIL import Image
495
+ import numpy as np
496
+ import torch.nn.functional as F
497
+
498
+ vae = AutoencoderKL(
499
+ embed_dim=16, ch_mult=(1, 1, 2, 2, 4),
500
+ ckpt_path='checkpoints/kl16.ckpt')
501
+
502
+ image = Image.open('data/ILSVRC2012_val_00023344.JPEG')
503
+ image = torch.from_numpy(np.array(image))
504
+ image = image.permute(2, 0, 1).float() / 255
505
+ image = 2 * image - 1
506
+
507
+ x = F.interpolate(image[None], size=(256, 256), mode='bilinear', align_corners=True)
508
+
509
+ print(x.shape)
510
+
511
+ with torch.no_grad():
512
+ z = vae.encode(x).sample()
513
+ print(z.shape)
514
+ x_rec = vae.decode(z)[0]
515
+
516
+ x_rec = (x_rec + 1.0) * 255 / 2
517
+ x_rec = torch.clamp(x_rec, min=0, max=255)
518
+ x_rec = x_rec.to(torch.uint8)
519
+
520
+ x_rec = x_rec.permute(1, 2, 0)
521
+
522
+ x_rec = Image.fromarray(x_rec.numpy())
523
+
524
+ x_rec.show()
525
+
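
The closed-form expressions in DiagonalGaussianDistribution.kl() and nll() are the standard diagonal-Gaussian KL divergence and negative log-likelihood. Below is a minimal, self-contained sanity check of the KL term against torch.distributions; the tensor shapes are illustrative only and the snippet is not part of the uploaded module.

import torch
from torch.distributions import Normal, kl_divergence

mean = torch.randn(2, 4, 8, 8)
logvar = torch.randn(2, 4, 8, 8).clamp(-30.0, 20.0)
var = logvar.exp()

# Closed form used by kl(): 0.5 * sum(mu^2 + var - 1 - logvar) over C, H, W
kl_closed = 0.5 * torch.sum(mean.pow(2) + var - 1.0 - logvar, dim=[1, 2, 3])

# Same quantity via torch.distributions, reduced over the same dims
p = Normal(mean, (0.5 * logvar).exp())
q = Normal(torch.zeros_like(mean), torch.ones_like(mean))
kl_ref = kl_divergence(p, q).sum(dim=[1, 2, 3])

print(torch.allclose(kl_closed, kl_ref, atol=1e-5))  # expected: True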
src/models/skywork_unipic_dev.py ADDED
@@ -0,0 +1,645 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.nn.modules.module import T
4
+ from mmengine.model import BaseModel
5
+ from torch.autograd.function import Function
6
+ from mmengine.logging import print_log
7
+ from xtuner.model.utils import guess_load_checkpoint
8
+ import os
9
+ #from .skywork_unipic import SkyworkUnipic
10
+ from .skywork_unipic_siglip import SkyworkUnipic
11
+ from xtuner.utils import IMAGE_TOKEN_INDEX
12
+ import torch.distributed as dist
13
+ import json
14
+ from einops import rearrange
15
+
16
+
17
+ def _load_state_dict_with_ds(module_to_load, state_dict, start_prefix="", strict=True):
18
+ try:
19
+ import deepspeed
20
+ except ImportError:
21
+ raise ImportError("deepspeed is not installed. Please install deepspeed to use this feature.")
22
+
23
+ # copy state_dict so _load_from_state_dict can modify it
24
+ metadata = getattr(state_dict, "_metadata", None)
25
+ state_dict = state_dict.copy()
26
+ if metadata is not None:
27
+ state_dict._metadata = metadata
28
+
29
+ error_msgs = []
30
+ missing_keys = []
31
+ unexpected_keys = []
32
+
33
+ def load(module: torch.nn.Module, state_dict, prefix=""):
34
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
35
+ args = (state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
36
+ # Parameters of module and children will start with prefix. We can exit early if there are none in this
37
+ # state_dict
38
+ if len([key for key in state_dict if key.startswith(prefix)]) > 0:
39
+ # In sharded models, each shard has only part of the full state_dict, so only gather
40
+ # parameters that are in the current state_dict.
41
+ named_parameters = dict(
42
+ module.named_parameters(prefix=prefix[:-1], recurse=False)
43
+ )
44
+ params_to_gather = [
45
+ named_parameters[k]
46
+ for k in state_dict.keys()
47
+ if k in named_parameters
48
+ ]
49
+ if len(params_to_gather) > 0:
50
+ # because zero3 puts placeholders in model params, this context
51
+ # manager gathers (unpartitions) the params of the current layer, then loads from
52
+ # the state dict and then re-partitions them again
53
+ with deepspeed.zero.GatheredParameters(
54
+ params_to_gather, modifier_rank=0
55
+ ):
56
+ if deepspeed.comm.get_rank() == 0:
57
+ module._load_from_state_dict(*args)
58
+ else:
59
+ module._load_from_state_dict(*args)
60
+
61
+ for name, child in module._modules.items():
62
+ if child is not None:
63
+ load(child, state_dict, prefix + name + ".")
64
+
65
+ load(module_to_load, state_dict, start_prefix)
66
+ if len(missing_keys) > 0:
67
+ print_log(f"[WARNING] Missing keys: {missing_keys}")
68
+ if len(unexpected_keys) > 0:
69
+ print_log(f"[WARNING] Unexpected keys: {unexpected_keys}")
70
+ if error_msgs:
71
+ raise RuntimeError(
72
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
73
+ module_to_load.__class__.__name__, "\n\t".join(error_msgs)
74
+ )
75
+ )
76
+
77
+
78
+ class _ScaleGradient(Function):
79
+ @staticmethod
80
+ def forward(ctx, input, scale):
81
+ ctx.scale = scale
82
+ return input
83
+
84
+ @staticmethod
85
+ def backward(ctx, grad_output):
86
+ return grad_output * ctx.scale, None
87
+
88
+
89
+ class SkyworkUnipicDev(SkyworkUnipic, BaseModel):
90
+ def __init__(
91
+ self,
92
+ grad_scale=0.1,
93
+ loss_weights=None,
94
+ pretrained_pth=None,
95
+ mar_path=None,
96
+ siglip_proj_path=None,
97
+ freeze_llm=False,
98
+ freeze_mar=False,
99
+ freeze_mar_decoder=False,
100
+ freeze_siglip_proj=False,
101
+ gradient_checkpointing=True,
102
+ **kwargs,
103
+ ):
104
+ if loss_weights is None:
105
+ loss_weights = {
106
+ "image2text": 0.01,
107
+ "text2image": 1.0,
108
+ "image_edit": 1.0,
109
+ "contrastive": 0.1,
110
+ }
111
+ super().__init__(**kwargs)
112
+
113
+ self.grad_scale = grad_scale
114
+ self.loss_weights = loss_weights
115
+ self.pretrained_pth = pretrained_pth
116
+ self.mar_path = mar_path
117
+ self.siglip_proj_path = siglip_proj_path
118
+
119
+ # Determine the distributed rank
120
+ rank = dist.get_rank() if dist.is_initialized() else 0
121
+
122
+ # === Load pretrained weights ===
123
+ if pretrained_pth:
124
+ self.load_hf_weights(
125
+ skywork_unipic_ckpt=pretrained_pth,
126
+ siglip_proj_path=siglip_proj_path,
127
+ mar_path=mar_path
128
+ )
129
+
130
+ # === Freeze modules ===
131
+ if freeze_llm:
132
+ self.llm.requires_grad_(False)
133
+ if freeze_mar:
134
+ self.mar.requires_grad_(False)
135
+ if freeze_mar_decoder:
136
+ # Freeze only the MAR decoder components
137
+ for param in self.mar.decoder_embed.parameters():
138
+ param.requires_grad = False
139
+ for block in self.mar.decoder_blocks:
140
+ for param in block.parameters():
141
+ param.requires_grad = False
142
+ for param in self.mar.decoder_norm.parameters():
143
+ param.requires_grad = False
144
+ if isinstance(self.mar.decoder_pos_embed_learned, torch.nn.Parameter):
145
+ self.mar.decoder_pos_embed_learned.requires_grad = False
146
+ if isinstance(self.mar.diffusion_pos_embed_learned, torch.nn.Parameter):
147
+ self.mar.diffusion_pos_embed_learned.requires_grad = False
148
+ if freeze_siglip_proj:
149
+ self.siglip2_proj.requires_grad_(False)
150
+
151
+ # === Gradient checkpointing ===
152
+ if gradient_checkpointing:
153
+ self.gradient_checkpointing_enable()
154
+ else:
155
+ self.gradient_checkpointing_disable()
156
+
157
+
158
+ def load_hf_weights(self,
159
+ skywork_unipic_ckpt: str = None,
160
+ siglip_proj_path: str = None,
161
+ mar_path: str = None):
162
+ """Load SkyworkUnipic (optional), SigLIP2, and MAR weights in one place."""
163
+ device = "cpu"
164
+ state_dict = {}
165
+
166
+ def _print_load_result(module_name, missing, unexpected):
167
+ print_log(f"[INFO] Loaded {module_name}. missing={len(missing)}, unexpected={len(unexpected)}")
168
+
169
+ # === SkyworkUnipic main model (optional) ===
170
+ if skywork_unipic_ckpt:
171
+ print_log(f"[INFO] Loading SkyworkUnipic checkpoint from: {skywork_unipic_ckpt}")
172
+ # Load the checkpoint (supports a single file or a sharded directory)
173
+ if os.path.isfile(skywork_unipic_ckpt):
174
+ skywork_unipic_state = torch.load(skywork_unipic_ckpt, map_location=device)
175
+ else:
176
+ idx = os.path.join(skywork_unipic_ckpt, "pytorch_model.bin.index.json")
177
+ if os.path.exists(idx):
178
+ with open(idx, 'r') as f:
179
+ index = json.load(f)
180
+ skywork_unipic_state = {}
181
+ for shard in sorted(set(index["weight_map"].values())):
182
+ shard_path = os.path.join(skywork_unipic_ckpt, shard)
183
+ skywork_unipic_state.update(torch.load(shard_path, map_location=device))
184
+ else:
185
+ bin_path = os.path.join(skywork_unipic_ckpt, "pytorch_model.bin")
186
+ skywork_unipic_state = torch.load(bin_path, map_location=device)
187
+
188
+ # Drop MAR pos_embed keys that the SkyworkUnipic checkpoint may carry, to avoid overwriting them
189
+
190
+ # for key in [
191
+ # "mar.encoder_pos_embed_learned",
192
+ # "mar.decoder_pos_embed_learned",
193
+ # "mar.diffusion_pos_embed_learned"
194
+ # ]:
195
+ # if key in skywork_unipic_state:
196
+ # print_log(f"[INFO] Dropping `{key}` from SkyworkUnipic checkpoint")
197
+ # del skywork_unipic_state[key]
198
+ model_dict = self.state_dict()
199
+
200
+ filtered_checkpoint = {}
201
+ shape_mismatch_keys = []
202
+
203
+ for k, v in skywork_unipic_state.items():
204
+ if k in model_dict:
205
+ if v.shape == model_dict[k].shape:
206
+ filtered_checkpoint[k] = v
207
+ else:
208
+ shape_mismatch_keys.append((k, v.shape, model_dict[k].shape))
209
+
210
+ missing, unexpected = self.load_state_dict(filtered_checkpoint, strict=False)
211
+ # Print the mismatched keys and their shapes
212
+ if shape_mismatch_keys:
213
+ print("The following keys were skipped due to shape mismatches:")
214
+ for k, checkpoint_shape, model_shape in shape_mismatch_keys:
215
+ print(f" - {k}:")
216
+ print(f" shape in checkpoint: {checkpoint_shape}")
217
+ print(f" shape in current model: {model_shape}")
218
+ else:
219
+ print("All key shapes matched; no parameters were skipped")
220
+
221
+ # missing, unexpected = self.load_state_dict(skywork_unipic_state, strict=False)
222
+ _print_load_result("SkyworkUnipic", missing, unexpected)
223
+ else:
224
+ print_log("[INFO] Skipping SkyworkUnipic checkpoint loading")
225
+
226
+ # === SigLIP2 weights ===
227
+ if siglip_proj_path:
228
+ print_log(f"[INFO] Loading SigLIP2 weights from: {siglip_proj_path}")
229
+ siglip_state = torch.load(
230
+ siglip_proj_path, map_location="cpu", weights_only=False
231
+ )
232
+ # Handle checkpoints wrapped as {"model": {...}}
233
+ if isinstance(siglip_state, dict) and "model" in siglip_state:
234
+ siglip_state = siglip_state["model"]
235
+ missing, unexpected = self.siglip2_proj.load_state_dict(
236
+ siglip_state, strict=False
237
+ )
238
+ _print_load_result("SigLIP2", missing, unexpected)
239
+ else:
240
+ print_log("[INFO] No SigLIP2 checkpoint provided, skipping")
241
+
242
+ # === MAR weights ===
243
+ if mar_path:
244
+ print_log(f"[INFO] Loading MAR weights from: {mar_path}")
245
+ mar_state = torch.load(mar_path, map_location="cpu", weights_only=False)
246
+ # Support either a "model_ema" or a "model" dict
247
+
248
+ if isinstance(mar_state, dict) and "model_ema" in mar_state:
249
+ mar_state = mar_state["model_ema"]
250
+
251
+ elif isinstance(mar_state, dict) and "model" in mar_state:
252
+ mar_state = mar_state["model"]
253
+
254
+
255
+ # If keys carry a "mar." prefix, strip it from all of them
256
+ if any(k.startswith("mar.") for k in mar_state):
257
+ filtered_mar = {
258
+ k.replace("mar.", "", 1): v
259
+ for k, v in mar_state.items()
260
+ if k.startswith("mar.")
261
+ }
262
+ else:
263
+ filtered_mar = mar_state
264
+
265
+ missing, unexpected = self.mar.load_state_dict(
266
+ filtered_mar, strict=False
267
+ )
268
+ _print_load_result("MAR", missing, unexpected)
269
+ else:
270
+ print_log("[INFO] No MAR checkpoint provided, skipping")
271
+
272
+ return state_dict
273
+
274
+
275
+
276
+ def gradient_checkpointing_disable(self):
277
+ self.llm.gradient_checkpointing_disable()
278
+ self.mar.gradient_checkpointing_disable()
279
+
280
+ def gradient_checkpointing_enable(self):
281
+ self.llm.gradient_checkpointing_enable()
282
+ self.mar.gradient_checkpointing_enable()
283
+
284
+ def state_dict(self, *args, **kwargs):
285
+ state_dict = super().state_dict(*args, **kwargs)
286
+ state_dict = {k: v for k, v in state_dict.items()
287
+ if 'vae.' not in k}
288
+
289
+ return state_dict
290
+
291
+ def train(self: T, mode: bool = True) -> T:
292
+ super().train(mode=mode)
293
+ self.vae.train(mode=False)
294
+ return self
295
+
296
+ def text2image_loss(self, data_dict):
297
+ x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
298
+ x = self.encode(x) # b m n c
299
+ b, m, n, _ = x.shape
300
+ gt_latents = x.clone().detach().view(b, m*n, -1)
301
+
302
+ orders = self.mar.sample_orders(bsz=b, seq_len=m*n)
303
+ mask = self.mar.random_masking(x.flatten(1, 2), orders)
304
+
305
+ input_ids = data_dict['input_ids'].to(self.device)
306
+ attention_mask = data_dict['attention_mask'].to(self.device)
307
+ x_enc = self.forward_mae_encoder(x, mask, input_ids=input_ids,
308
+ attention_mask=attention_mask)
309
+ z = self.mar.forward_mae_decoder(x_enc, mask, image_shape=(m, n))
310
+
311
+ loss = self.mar.forward_loss(z=z, target=gt_latents, mask=mask)
312
+
313
+ return loss
314
+
315
+ def image2text_loss(self, data_dict):
316
+ input_ids = data_dict['input_ids'].to(self.device)
317
+ attention_mask = data_dict['attention_mask'].to(self.device)
318
+ labels = data_dict['labels'].to(self.device)
319
+
320
+ pixel_values = data_dict.get('pixel_values', None)
321
+ # print("pixel_values batch:", pixel_values.shape)
322
+ # print("input_ids batch:", input_ids.shape)
323
+ if pixel_values is None:
324
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
325
+ _, z_null = self.extract_visual_feature(
326
+ torch.zeros(1, 16, 16, self.token_embed_dim,
327
+ dtype=self.dtype, device=self.device)
328
+ )
329
+ loss_null = z_null.mean() * 0.0
330
+ print(f"No image found in this batch!", flush=True)
331
+ else:
332
+ x = pixel_values.to(dtype=self.dtype, device=self.device)
333
+ x = self.encode(x) # b m n c
334
+ _, z_enc = self.extract_visual_feature(x)
335
+
336
+ if self.grad_scale is not None:
337
+ z_enc = _ScaleGradient.apply(z_enc, self.grad_scale)
338
+
339
+ inputs_embeds = z_enc.new_zeros(*input_ids.shape, self.llm.config.hidden_size)
340
+
341
+ self.tokenizer.add_tokens(["<image>"], special_tokens=True)
342
+ IMAGE_TOKEN_INDEX = self.tokenizer.convert_tokens_to_ids("<image>")
343
+ # print(f"IMAGE_TOKEN_INDEX: {IMAGE_TOKEN_INDEX}")
344
+ img_tokens = (input_ids == IMAGE_TOKEN_INDEX).sum().item()  # input_ids is already a tensor
345
+ # print(f"[sanity check] input_ids length: {input_ids.shape[1]}, image token count: {img_tokens}\n")
346
+
347
+ inputs_embeds[input_ids == IMAGE_TOKEN_INDEX] = z_enc.flatten(0, 1)
348
+ inputs_embeds[input_ids != IMAGE_TOKEN_INDEX] = self.llm.get_input_embeddings()(
349
+ input_ids[input_ids != IMAGE_TOKEN_INDEX])
350
+ loss_null = 0.0
351
+
352
+ output = self.llm_model(inputs_embeds=inputs_embeds,
353
+ attention_mask=attention_mask,
354
+ return_dict=True)
355
+
356
+ last_hidden_state = output.last_hidden_state[:, :-1]
357
+ labels = labels[:, 1:]
358
+ last_hidden_state = last_hidden_state[labels >= 0]
359
+ labels = labels[labels >= 0]
360
+ logits = self.llm.get_output_embeddings()(last_hidden_state)
361
+
362
+ loss_i2t = F.cross_entropy(input=logits, target=labels)
363
+
364
+ return loss_i2t + loss_null
365
+
366
+ # def image_edit_loss(self, data_dict):
367
+ # # 1. 图像前向:拼 batch 并编码到视觉特征
368
+ # x_src = data_dict['pixel_values_src'].to(dtype=self.dtype, device=self.device) # 源图像批次,shape=[b_src, C, H, W]
369
+ # x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device) # 编辑图像批次,shape=[b_edit, C, H, W]
370
+ # print_log(f"[DEBUG image_edit_loss] x_src.shape = {x_src.shape}, x.shape = {x.shape}", level="WARNING")
371
+
372
+ # # b_edit 应该 >= b_src
373
+ # assert x.shape[0] >= x_src.shape[0], \
374
+ # f"编辑批次大小 ({x.shape[0]}) 必须 >= 源图像批次大小 ({x_src.shape[0]})"
375
+
376
+ # # 拼接并一次性编码
377
+ # x_all = torch.cat([x_src, x], dim=0) # shape=[b_src + b_edit, C, H, W]
378
+ # x_all = self.encode(x_all) # shape=[b_src + b_edit, m, n, c]
379
+ # # 分割回源/编辑两部分
380
+ # x_src_enc, x_enc = x_all.split([x_src.shape[0], x.shape[0]], dim=0)
381
+ # # x_src_enc.shape=[b_src, m, n, c], x_enc.shape=[b_edit, m, n, c]
382
+
383
+ # # 2. 提取视觉特征:x_con 用于 decoder 条件,z_src 用于填充文本中的 <image> token
384
+ # x_con, z_src = self.extract_visual_feature(x_src_enc)
385
+ # if self.grad_scale is not None:
386
+ # x_con = _ScaleGradient.apply(x_con, self.grad_scale)
387
+ # z_src = _ScaleGradient.apply(z_src, self.grad_scale)
388
+ # # z_src.shape = [b_src, m*n, C]
389
+
390
+ # # 3. 文本条件分支:构造 inputs_embeds
391
+ # attention_mask = data_dict['attention_mask'].to(self.device) # shape=[b_edit, seq_len]
392
+ # input_ids = data_dict['input_ids'].to(self.device) # shape=[b_edit, seq_len]
393
+ # b_edit, seq_len = input_ids.shape
394
+ # hidden_size = self.llm.config.hidden_size
395
+
396
+ # # 先准备一个全 0 的 inputs_embeds
397
+ # inputs_embeds = z_src.new_zeros(b_edit, seq_len, hidden_size) # shape=[b_edit, seq_len, hidden_size]
398
+
399
+ # # 找到所有 <image> token 位置的 mask
400
+ # mask_imgpos = (input_ids == IMAGE_TOKEN_INDEX) # bool tensor [b_edit, seq_len]
401
+
402
+ # # 需要将单个 z_src 展开成 b_edit 份,再按 mask_imgpos 填入
403
+ # # 1) expand:把 z_src 从 [b_src, m*n, C] → [b_edit, m*n, C]
404
+ # # (一般 b_src=1,所以就是复制那一份)
405
+ # z_src_rep = z_src.expand(b_edit, -1, -1) # [b_edit, m*n, C]
406
+ # # 2) flatten:将二维展开到一维,对应 mask_imgpos.sum() 个位置
407
+ # flat_z = z_src_rep.flatten(0, 1) # [b_edit*m*n, C]
408
+
409
+ # # **重要检查**:保证 mask_imgpos 中 True 的数量 == flat_z.shape[0]
410
+ # img_tokens_count = mask_imgpos.sum().item()
411
+ # assert img_tokens_count == flat_z.shape[0], \
412
+ # f"<image> token 数 ({img_tokens_count}) 不等于视觉特征数 ({flat_z.shape[0]})"
413
+
414
+ # # 填充视觉 token 对应位置
415
+ # inputs_embeds[mask_imgpos] = flat_z
416
+
417
+ # # 剩下的位置用文本 embedding
418
+ # txt_pos = ~mask_imgpos
419
+ # txt_embeddings = self.llm.get_input_embeddings()(input_ids[txt_pos])
420
+ # inputs_embeds[txt_pos] = txt_embeddings
421
+
422
+ # # 4. MAE-style 重建分支:在 decoder 前注入 inputs_embeds 与 attention_mask
423
+ # b, m, n, c = x_enc.shape
424
+ # gt = x_enc.view(b, m*n, c) # 作为重建目标
425
+ # orders = self.mar.sample_orders(bsz=b, seq_len=m*n)
426
+ # mask = self.mar.random_masking(x_enc.flatten(1, 2), orders)
427
+
428
+ # # 带条件的 encoder forward
429
+ # x_enc_out = self.forward_mae_encoder(
430
+ # x_enc,
431
+ # mask,
432
+ # inputs_embeds=inputs_embeds,
433
+ # attention_mask=attention_mask
434
+ # )
435
+ # # decoder 重建
436
+ # z_dec = self.mar.forward_mae_decoder(
437
+ # x_enc_out,
438
+ # mask,
439
+ # image_shape=(m, n),
440
+ # x_con=x_con
441
+ # )
442
+ # # 计算损失
443
+ # loss = self.mar.forward_loss(z=z_dec, target=gt, mask=mask)
444
+ # return loss
445
+
446
+
447
+ # def image_edit_loss_vae(self, data_dict):
448
+ # """
449
+ # 计算图像编辑任务的损失。
450
+ # 参考图(x_src)的特征直接作为条件(x_con)送入解码器,不参与编码器重建。
451
+ # 编码器(encoder)仅在目标图(x_tgt)上进行掩码重建,并接收文本和参考图的上下文信息。
452
+ # """
453
+ # # === 步骤 1: 读入数据 ===
454
+ # x_src = data_dict['pixel_values_src'].to(self.device).to(self.dtype)
455
+ # x_tgt = data_dict['pixel_values'].to(self.device).to(self.dtype)
456
+ # attention_mask = data_dict['attention_mask'].to(self.device)
457
+ # input_ids = data_dict['input_ids'].to(self.device)
458
+ # # IMG_TOKEN_INDEX = self.tokenizer.convert_tokens_to_ids("<image>")
459
+ # B = x_tgt.shape[0]
460
+
461
+ # # === 步骤 2: 处理参考图 (Reference Image) ===
462
+ # # VAE编码,不计算梯度
463
+ # with torch.no_grad():
464
+ # z_src_latent = self.encode(x_src) # [B, m, n, token_dim]
465
+
466
+ # # 将VAE潜变量转换为解码器条件(x_con)和LLM输入(z_src_buf)
467
+ # # 这一步实现了 "参考图潜变量 -> 解码器" 的直接通路
468
+ # x_con, z_src_buf = self.vae_latent_to_decoder_feature(z_src_latent)
469
+ # # x_con: [B, 4096, enc_dim] -> 用于解码器
470
+ # # z_src_buf: [B, 4160, llm_dim] -> 用于LLM
471
+
472
+ # # === 步骤 3: 构建LLM的输入 (inputs_embeds) ===
473
+ # # 结合文本指令(input_ids)和参考图特征(z_src_buf)
474
+ # _, T = input_ids.shape
475
+ # H_llm = self.llm.config.hidden_size
476
+ # inputs_embeds = torch.zeros(B, T, H_llm, device=self.device, dtype=z_src_buf.dtype)
477
+
478
+ # # 填充<image> token和文本token的嵌入
479
+ # inputs_embeds[input_ids == IMG_TOKEN_INDEX] = z_src_buf.flatten(0, 1)
480
+ # # input_ids 为33280
481
+ # # z_src_buf.flatten(0, 1) 为33792 为什么 会比input_ids 多512个呢?
482
+ # inputs_embeds[input_ids != IMG_TOKEN_INDEX] = self.llm.get_input_embeddings()(
483
+ # input_ids[input_ids != IMG_TOKEN_INDEX]
484
+ # )
485
+
486
+ # # === 步骤 4: 处理目标图 (Target Image) 并进行编码器前向传播 ===
487
+ # # VAE编码目标图,不计算梯度
488
+ # with torch.no_grad():
489
+ # z_tgt_latent = self.encode(x_tgt) # [B, m, n, token_dim]
490
+
491
+ # # 为目标图潜变量创建掩码(mask)以进行MAE重建
492
+ # B, m, n, token_dim = z_tgt_latent.shape
493
+ # patch_tokens_tgt = z_tgt_latent.view(B, m * n, token_dim) # 作为重建的目标
494
+ # orders = self.mar.sample_orders(bsz=B, seq_len=m * n)
495
+ # mask = self.mar.random_masking(patch_tokens_tgt, orders)
496
+
497
+ # # **核心**: 编码器只处理目标图(z_tgt_latent)的可见部分,并接收LLM的上下文
498
+ # x_enc = self.forward_mae_encoder(
499
+ # z_tgt_latent, # 目标图潜变量
500
+ # mask,
501
+ # detach=False,
502
+ # inputs_embeds=inputs_embeds, # 包含文本和参考图信息的上下文
503
+ # attention_mask=attention_mask
504
+ # )
505
+
506
+ # # === 步骤 5: 解码器重建 ===
507
+ # # 解码器使用编码器的输出(x_enc)和参考图的特征(x_con)来重建完整的潜在表示
508
+ # z_pred = self.mar.forward_mae_decoder(
509
+ # x_enc,
510
+ # mask,
511
+ # image_shape=(m, n),
512
+ # x_con=x_con # ★ 参考图特征直接作用于此
513
+ # )
514
+
515
+ # # === 步骤 6: 计算损失 ===
516
+ # loss = self.mar.forward_loss(
517
+ # z=z_pred,
518
+ # target=patch_tokens_tgt,
519
+ # mask=mask
520
+ # )
521
+ # return loss
522
+
523
+ def image_edit_loss_contrastive(self, data_dict):
524
+ # Step 1: obtain the image features
525
+ x_src = data_dict['pixel_values_src'].to(dtype=self.dtype, device=self.device)
526
+ x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
527
+ assert len(x_src) >= len(x)
528
+ x_src, x = self.encode(torch.cat([x_src, x])).split([len(x_src), len(x)], dim=0)
529
+
530
+ # Step 2: text inputs
531
+ attention_mask = data_dict['attention_mask'].to(self.device)
532
+ input_ids = data_dict['input_ids'].to(self.device)
533
+
534
+ x_con, z_src = self.extract_visual_feature(x_src)
535
+ if self.grad_scale is not None:
536
+ z_src = _ScaleGradient.apply(z_src, self.grad_scale)
537
+ x_con = _ScaleGradient.apply(x_con, self.grad_scale)
538
+
539
+ inputs_embeds = z_src.new_zeros(*input_ids.shape, self.llm.config.hidden_size)
540
+ # IMAGE_TOKEN_INDEX = self.tokenizer.convert_tokens_to_ids("<image>")
541
+ inputs_embeds[input_ids == IMAGE_TOKEN_INDEX] = z_src.flatten(0, 1)
542
+ inputs_embeds[input_ids != IMAGE_TOKEN_INDEX] = self.llm.get_input_embeddings()(
543
+ input_ids[input_ids != IMAGE_TOKEN_INDEX]
544
+ )
545
+
546
+ # Step 3: compute the reconstruction loss
547
+ b, m, n, _ = x.shape
548
+ gt_latents = x.clone().detach().view(b, m * n, -1)
549
+ orders = self.mar.sample_orders(bsz=b, seq_len=m*n)
550
+ mask = self.mar.random_masking(x.flatten(1, 2), orders)
551
+ x_enc = self.forward_mae_encoder(x, mask,
552
+ inputs_embeds=inputs_embeds,
553
+ attention_mask=attention_mask)
554
+ z = self.mar.forward_mae_decoder(x_enc, mask, image_shape=(m, n), x_con=x_con)
555
+ rec_loss = self.mar.forward_loss(z=z, target=gt_latents, mask=mask)
556
+
557
+ # Step 4: Contrastive loss between repeat and edit
558
+ # Assumes an even batch size arranged as (repeat, edit) pairs
559
+ z_src_flat = z_src.mean(dim=1) # [B, D] global pooling
560
+ z_src_flat = F.normalize(z_src_flat, dim=-1)
561
+
562
+ repeat_z = z_src_flat[::2] # even index
563
+ edit_z = z_src_flat[1::2] # odd index
564
+
565
+ logits = torch.matmul(edit_z, repeat_z.T) / 0.07 # [B, B]
566
+ labels = torch.arange(logits.size(0), device=logits.device)
567
+ contrastive_loss = F.cross_entropy(logits, labels)
568
+
569
+ return rec_loss + self.loss_weights.get("contrastive") * contrastive_loss
570
+
571
+ def image_edit_loss(self, data_dict):
572
+ # Multi-turn editing is also supported
573
+ x_src = data_dict['pixel_values_src'].to(dtype=self.dtype, device=self.device)
574
+ x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
575
+ # print_log(f"[DEBUG] x_src.shape = {x_src.shape}, x.shape = {x.shape}")
576
+
577
+ # assert len(x_src) >= len(x)
578
+ x_cat = torch.cat([x_src, x], dim=0)
579
+ x_src, x = self.encode(x_cat).split([len(x_src), len(x)], dim=0)
580
+
581
+ # Prepare context, including source images and instructions
582
+ attention_mask = data_dict['attention_mask'].to(self.device)
583
+ input_ids = data_dict['input_ids'].to(self.device)
584
+
585
+ x_con, z_src = self.extract_visual_feature(x_src)
586
+ if self.grad_scale is not None:
587
+ z_src = _ScaleGradient.apply(z_src, self.grad_scale)
588
+ x_con = _ScaleGradient.apply(x_con, self.grad_scale)
589
+
590
+ inputs_embeds = z_src.new_zeros(*input_ids.shape, self.llm.config.hidden_size)
591
+
592
+ self.tokenizer.add_tokens(["<image>"], special_tokens=True)
593
+
594
+ IMAGE_TOKEN_INDEX = self.tokenizer.convert_tokens_to_ids("<image>")
595
+ # print("tokenizer idx in skywork_unipic_dev=", self.tokenizer.convert_tokens_to_ids("<image>"))
596
+
597
+ inputs_embeds[input_ids == IMAGE_TOKEN_INDEX] = z_src.flatten(0, 1)
598
+ inputs_embeds[input_ids != IMAGE_TOKEN_INDEX] = self.llm.get_input_embeddings()(
599
+ input_ids[input_ids != IMAGE_TOKEN_INDEX]
600
+ )
601
+
602
+ # --------------------------------------------------
603
+ # 3. MAE-style reconstruction
604
+ # --------------------------------------------------
605
+
606
+ b, m, n, _ = x.shape
607
+ gt_latents = x.clone().detach().view(b, m * n, -1)
608
+ orders = self.mar.sample_orders(bsz=b, seq_len=m*n)
609
+ mask = self.mar.random_masking(x.flatten(1, 2), orders)
610
+ x_enc = self.forward_mae_encoder(x, mask,
611
+ inputs_embeds=inputs_embeds,
612
+ attention_mask=attention_mask)
613
+ z = self.mar.forward_mae_decoder(x_enc, mask, image_shape=(m, n), x_con=x_con)
614
+
615
+ loss = self.mar.forward_loss(z=z, target=gt_latents, mask=mask)
616
+ return loss
617
+
618
+
619
+
620
+ def forward(self, data, data_samples=None, mode='loss'):
621
+ if mode == 'loss':
622
+ return self.compute_loss(data_dict=data)
623
+ else:
624
+ raise NotImplementedError
625
+
626
+ def compute_loss(self, data_dict):
627
+ losses = {}
628
+ for data_type, batch_data in data_dict.items():
629
+ if 'text2image' in data_type:
630
+ loss = self.text2image_loss(batch_data)
631
+ elif 'image2text' in data_type:
632
+ loss = self.image2text_loss(batch_data)
633
+ elif 'image_edit' in data_type:
634
+ loss = self.image_edit_loss(batch_data)
635
+ else:
636
+ raise NotImplementedError(f"Unknown data_type: {data_type}")
637
+ weight = self.loss_weights.get(data_type, 1.0)
638
+ losses[f'loss_{data_type}'] = loss * weight
639
+ return losses
640
+
641
+
642
+
643
+
644
+
645
+
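
_ScaleGradient above is an identity in the forward pass and multiplies the incoming gradient by a constant in the backward pass; grad_scale uses it to damp the gradients flowing from the text-side losses back into the visual features. A self-contained sketch of that behaviour, with a toy tensor that is purely illustrative:

import torch
from torch.autograd.function import Function

class ScaleGradient(Function):
    # identity forward, gradient scaled by `scale` in backward
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None  # no gradient w.r.t. scale

x = torch.ones(3, requires_grad=True)
ScaleGradient.apply(x, 0.1).sum().backward()
print(x.grad)  # tensor([0.1000, 0.1000, 0.1000]) rather than all ones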
src/models/skywork_unipic_ori.py ADDED
@@ -0,0 +1,350 @@
1
+ import torch
2
+ import math
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ import contextlib
6
+ from einops import rearrange
7
+ from transformers.cache_utils import DynamicCache
8
+ from src.builder import BUILDER
9
+ from tqdm import tqdm
10
+ from torch.nn.utils.rnn import pad_sequence
11
+ from transformers.integrations.deepspeed import (
12
+ is_deepspeed_zero3_enabled,
13
+ set_hf_deepspeed_config,
14
+ unset_hf_deepspeed_config,
15
+ deepspeed_config
16
+ )
17
+
18
+ @contextlib.contextmanager
19
+ def temporarily_disable_deepspeed_zero3():
20
+ if is_deepspeed_zero3_enabled():
21
+ config = deepspeed_config()
22
+ print(f'[DEBUG] ds config={config}')
23
+ unset_hf_deepspeed_config()
24
+ yield
25
+ set_hf_deepspeed_config(config)
26
+ else:
27
+ yield
28
+
29
+
30
+
31
+ def build_mlp(hidden_size, projector_dim, z_dim):
32
+ return nn.Sequential(
33
+ nn.Linear(hidden_size, projector_dim),
34
+ nn.SiLU(),
35
+ nn.Linear(projector_dim, z_dim),)
36
+
37
+
38
+ def mask_by_order(mask_len, order, bsz, seq_len):
39
+ masking = torch.zeros(bsz, seq_len, device=order.device)
40
+ masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()],
41
+ src=torch.ones(bsz, seq_len, device=order.device)).bool()
42
+ return masking
43
+
44
+
45
+ class SkyworkUnipic(nn.Module):
46
+ def __init__(self,
47
+ vae,
48
+ vae_scale,
49
+ llm,
50
+ mar,
51
+ tokenizer,
52
+ prompt_template):
53
+ super().__init__()
54
+ with temporarily_disable_deepspeed_zero3():
55
+ # VAE
56
+ self.vae = BUILDER.build(vae)
57
+ self.vae.requires_grad_(False)
58
+ self.vae_scale = vae_scale
59
+
60
+ # LLM
61
+ self.llm = BUILDER.build(llm)
62
+ self.tokenizer = BUILDER.build(tokenizer)
63
+ self.prompt_template = prompt_template
64
+
65
+ self.tokenizer.add_tokens(["<image>"], special_tokens=True)
66
+ image_token_idx = self.tokenizer.convert_tokens_to_ids("<image>")
67
+ print(f"Registered <image> token at index {image_token_idx}")
68
+
69
+ # MAR
70
+ self.mar = BUILDER.build(mar)
71
+ # projection layers
72
+ self.proj_in = build_mlp(hidden_size=self.mar.encoder_embed_dim,
73
+ projector_dim=self.llm.config.hidden_size,
74
+ z_dim=self.llm.config.hidden_size)
75
+ self.proj_out = build_mlp(hidden_size=self.llm.config.hidden_size,
76
+ projector_dim=self.llm.config.hidden_size,
77
+ z_dim=self.mar.encoder_embed_dim)
78
+
79
+
80
+ @property
81
+ def llm_model(self):
82
+ return self.llm.model
83
+
84
+ @property
85
+ def device(self):
86
+ return self.llm.device
87
+
88
+ @property
89
+ def dtype(self):
90
+ return self.llm.dtype
91
+
92
+ @property
93
+ def gen_seq_len(self):
94
+ return self.mar.seq_len
95
+
96
+ @property
97
+ def token_embed_dim(self):
98
+ return self.vae.embed_dim * (self.mar.patch_size ** 2)
99
+
100
+ @torch.no_grad()
101
+ def encode(self, x):
102
+ posterior = self.vae.encode(x)
103
+ z = posterior.mode().mul_(self.vae_scale)
104
+ z = rearrange(z, 'b c (m p) (n q) -> b m n (c p q)',
105
+ p=self.mar.patch_size, q=self.mar.patch_size)
106
+
107
+ return z
108
+
109
+ @torch.no_grad()
110
+ def decode(self, z):
111
+ z /= self.vae_scale
112
+ z = rearrange(z, 'b m n (c p q) -> b c (m p) (n q)',
113
+ p=self.mar.patch_size, q=self.mar.patch_size)
114
+
115
+ x = self.vae.decode(z)
116
+ return x
117
+
118
+ def prepare_forward_input(self,
119
+ x,
120
+ inputs_embeds=None,
121
+ input_ids=None,
122
+ attention_mask=None,
123
+ past_key_values=None):
124
+ b, l, _ = x.shape
125
+ attention_mask = attention_mask.to(device=self.device, dtype=torch.bool)
126
+ attention_mask = torch.cat([
127
+ attention_mask, attention_mask.new_ones(b, l)
128
+ ], dim=1)
129
+ position_ids = torch.cumsum(attention_mask, dim=1) - 1
130
+ position_ids[position_ids < 0] = 0
131
+
132
+ # import pdb; pdb.set_trace()
133
+
134
+ # prepare context
135
+ if past_key_values is not None:
136
+ inputs_embeds = x
137
+ position_ids = position_ids[:, -l:]
138
+ else:
139
+ if inputs_embeds is None:
140
+ input_ids = input_ids.to(self.device)
141
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
142
+ inputs_embeds = torch.cat([inputs_embeds, x], dim=1)
143
+
144
+ return dict(inputs_embeds=inputs_embeds,
145
+ attention_mask=attention_mask,
146
+ position_ids=position_ids,
147
+ past_key_values=past_key_values)
148
+
149
+ def extract_visual_feature(self, x, mask=None, detach=False):
150
+ b, m, n, _ = x.shape
151
+ x = x.view(b, m*n, -1)
152
+ # x: b mn c
153
+ if mask is None:
154
+ mask = torch.zeros_like(x[..., 0])
155
+ null_embeds = self.mar.fake_latent.expand(x.shape[0], -1)
156
+ x_enc = self.mar.forward_mae_encoder(x, mask, null_embeds, image_shape=(m, n))
157
+
158
+ z_enc = self.proj_in(x_enc)
159
+ # Move buffers to the end of the image sequence
160
+ z_enc = torch.cat([
161
+ z_enc[:, self.mar.buffer_size:],
162
+ z_enc[:, :self.mar.buffer_size]], dim=1)
163
+
164
+ if detach:
165
+ x_enc = x_enc.detach()
166
+ z_enc = z_enc.detach()
167
+
168
+ return x_enc, z_enc
169
+
170
+ def vae_latent_to_decoder_feature(self, z_src_latent):
171
+ """
172
+ Returns:
173
+ x_con [B, buf_sz + m*n, enc_dim] for the MAE decoder
174
+ z_src_buf [B, buf_sz + m*n, llm_dim] to scatter into <image> tokens
175
+ """
176
+ B, m, n, token_dim = z_src_latent.shape
177
+ num_patches = m * n
178
+ enc_dim = self.mar.encoder_embed_dim # e.g. 1280
179
+ llm_dim = self.llm.config.hidden_size # e.g. 1536
180
+ buf_sz = self.mar.buffer_size # e.g. 64
181
+
182
+ # 1) flatten patches → [B,4096,token_dim]
183
+ patch_tokens = z_src_latent.view(B, num_patches, token_dim)
184
+
185
+ # 2) project to encoder dim → [B,4096,enc_dim]
186
+ z_enc = self.mar.z_proj(patch_tokens)
187
+ z_enc = self.mar.z_proj_ln(z_enc)
188
+
189
+ # (optional) add encoder pos embed for image part only
190
+ full_pos = self.mar.get_encoder_pos_embed(h=m, w=n) # [1, buf_sz+4096, enc_dim]
191
+ pos_img = full_pos[:, buf_sz:] # [1,4096,enc_dim]
192
+ z_enc = z_enc + pos_img
193
+
194
+ # 3) build x_con for MAE decoder: **one** buffer pad + image tokens
195
+ buf_enc = torch.zeros(B, buf_sz, enc_dim,
196
+ device=z_enc.device, dtype=z_enc.dtype)
197
+ x_con = torch.cat([buf_enc, z_enc], dim=1) # [B,4160,enc_dim]
198
+
199
+ # 4) build z_src_buf for LLM: **project the exact same** x_con, then rotate buffer→end
200
+ z_proj_llm = self.proj_in(x_con) # [B,4160,llm_dim]
201
+ # rotate: take image portion then buffer portion
202
+ z_src_buf = torch.cat([
203
+ z_proj_llm[:, buf_sz:], # [B,4096,llm_dim]
204
+ z_proj_llm[:, :buf_sz] # [B, 64,llm_dim]
205
+ ], dim=1) # [B,4160,llm_dim]
206
+
207
+ return x_con, z_src_buf
208
+
209
+ def forward_mae_encoder(self, x, mask, detach=False, **context):
210
+ b, m, n, _ = x.shape
211
+ x_enc, z_enc = self.extract_visual_feature(x, mask=mask, detach=detach)
212
+ inputs = self.prepare_forward_input(x=z_enc, **context)
213
+ output = self.llm_model(**inputs, return_dict=True)
214
+
215
+ z_llm = output.last_hidden_state[:, -z_enc.shape[1]:]
216
+
217
+ # move buffers back to the start of the image sequence
218
+ z_llm = torch.cat([
219
+ z_llm[:, -self.mar.buffer_size:],
220
+ z_llm[:, :-self.mar.buffer_size]], dim=1)
221
+
222
+ # residual learning
223
+ x_enc = x_enc + self.proj_out(z_llm)
224
+
225
+ return x_enc
226
+
227
+ @staticmethod
228
+ def curtail_cache(past_key_values, cur_len):
229
+ for past_key_values_ in past_key_values:
230
+ keys, values = past_key_values_
231
+ keys.data = keys.data[:, :, :cur_len]
232
+ values.data = values.data[:, :, :cur_len]
233
+
234
+ @torch.no_grad()
235
+ def prepare_text_conditions(self, prompt, cfg_prompt='Generate an image.'):
236
+ all_prompts = [self.prompt_template['INSTRUCTION'].format(input=prompt),
237
+ self.prompt_template['INSTRUCTION'].format(input=cfg_prompt)]
238
+
239
+
240
+ input_ids = [self.tokenizer.encode(p, add_special_tokens=True, return_tensors='pt')[0]
241
+ for p in all_prompts]
242
+ valid_lens = [len(input_ids_) for input_ids_ in input_ids]
243
+ input_ids = pad_sequence(input_ids, batch_first=True,
244
+ padding_value=self.tokenizer.eos_token_id)
245
+ attention_mask = torch.zeros_like(input_ids).bool()
246
+ for i in range(len(input_ids)):
247
+ attention_mask[i, :valid_lens[i]] = True
248
+
249
+ return dict(input_ids=input_ids.to(self.device),
250
+ attention_mask=attention_mask.to(self.device))
251
+
252
+ @torch.no_grad()
253
+ def sample(self,
254
+ input_ids=None, inputs_embeds=None,
255
+ attention_mask=None, num_iter=64, cfg=1.0, cfg_schedule="constant", temperature=1.0,
256
+ progress=False, mask=None, past_key_values=None, image_shape=None, x_con=None, **kwargs):
257
+ if inputs_embeds is None and input_ids is not None:
258
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
259
+
260
+ bsz = attention_mask.shape[0]
261
+ if cfg != 1.0:
262
+ assert bsz % 2 == 0
263
+
264
+ if image_shape is None:
265
+ m = n = int(self.gen_seq_len ** 0.5)
266
+ else:
267
+ m, n = image_shape
268
+
269
+ if mask is None:
270
+ mask = torch.ones(bsz, m*n, device=self.device, dtype=self.dtype)
271
+ else:
272
+ mask = mask.view(bsz, m*n)
273
+ tokens = torch.zeros(bsz, m*n, self.token_embed_dim,
274
+ device=self.device, dtype=self.dtype)
275
+ orders = self.mar.sample_orders(bsz, seq_len=m*n)
276
+ if cfg != 1.0:
277
+ orders[bsz//2:] = orders[:bsz//2]
278
+
279
+ indices = list(range(num_iter))
280
+ if progress:
281
+ indices = tqdm(indices)
282
+
283
+ # past key values can be prepared outside (usually in multi-turn editing)
284
+ if past_key_values is None:
285
+ output = self.llm_model(inputs_embeds=inputs_embeds,
286
+ attention_mask=None,
287
+ position_ids=None,
288
+ past_key_values=DynamicCache.from_legacy_cache(),
289
+ return_dict=True,
290
+ use_cache=True)
291
+ past_key_values = output.past_key_values
292
+
293
+ # generate latents
294
+ for step in indices:
295
+ cur_tokens = tokens.clone()
296
+ x_enc = self.forward_mae_encoder(tokens.view(bsz, m, n, -1),
297
+ mask.to(self.dtype),
298
+ past_key_values=past_key_values,
299
+ # inputs_embeds=inputs_embeds,
300
+ attention_mask=attention_mask)
301
+ # import pdb; pdb.set_trace()
302
+ self.curtail_cache(past_key_values, inputs_embeds.shape[1])
303
+ # import pdb; pdb.set_trace()
304
+
305
+ z = self.mar.forward_mae_decoder(x_enc, mask.to(self.dtype), image_shape=(m, n), x_con=x_con)
306
+
307
+ # mask ratio for the next round, following MaskGIT and MAGE.
308
+ mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
309
+ mask_len = torch.Tensor([np.floor(m*n * mask_ratio)]).to(self.device)
310
+
311
+ # masks out at least one for the next iteration
312
+ mask_len = torch.maximum(torch.Tensor([1]).to(self.device),
313
+ torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
314
+
315
+ # get masking for next iteration and locations to be predicted in this iteration
316
+ mask_next = mask_by_order(mask_len[0], orders, bsz, m*n).to(self.device)
317
+ if cfg != 1.0:
318
+ mask_next[bsz//2:] = mask_next[:bsz//2]
319
+ if step >= num_iter - 1:
320
+ mask_to_pred = mask[:bsz].bool()
321
+ else:
322
+ mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
323
+ mask = mask_next
324
+ # if not cfg == 1.0:
325
+ # mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
326
+
327
+ # sample token latents for this step
328
+ z = z[mask_to_pred.nonzero(as_tuple=True)]
329
+ # cfg schedule follow Muse
330
+ if cfg_schedule == "linear":
331
+ cfg_iter = 1 + (cfg - 1) * (m*n - mask_len[0]) / (m*n)
332
+ elif cfg_schedule == "constant":
333
+ cfg_iter = cfg
334
+ else:
335
+ raise NotImplementedError
336
+ sampled_token_latent = self.mar.diffloss.sample(z, temperature, cfg_iter).to(self.dtype)
337
+ # if not cfg == 1.0:
338
+ # sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0) # Remove null class samples
339
+ # mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
340
+
341
+ cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
342
+ if cfg != 1.0:
343
+ cur_tokens[bsz//2:] = cur_tokens[:bsz//2]
344
+ tokens = cur_tokens.clone()
345
+
346
+ pred = self.decode(tokens.view(bsz, m, n, -1))
347
+
348
+ if cfg != 1.0:
349
+ pred = pred[:bsz//2]
350
+ return pred
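
The sampling loop above follows the MaskGIT/MAGE cosine schedule: after step s of num_iter, roughly cos(pi/2 * (s+1)/num_iter) of the m*n positions stay masked for the next round. A standalone sketch of how many tokens get predicted per step; the 16x16 grid size is an illustrative assumption, no model is involved.

import math
import numpy as np

num_iter, seq_len = 8, 256   # e.g. a 16x16 latent grid
masked = seq_len             # everything starts masked

for step in range(num_iter):
    mask_ratio = np.cos(math.pi / 2.0 * (step + 1) / num_iter)
    # clamp as in sample(): keep at least 1 masked and reveal at least 1 per step
    mask_len = max(1, min(masked - 1, int(np.floor(seq_len * mask_ratio))))
    if step == num_iter - 1:
        predicted, masked = masked, 0        # final step predicts everything left
    else:
        predicted, masked = masked - mask_len, mask_len
    print(f"step {step}: predict {predicted:3d} tokens, {masked:3d} still masked")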
src/models/skywork_unipic_siglip.py ADDED
@@ -0,0 +1,342 @@
1
+ import torch
2
+ import math
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from transformers.cache_utils import DynamicCache
7
+ from src.builder import BUILDER
8
+ from tqdm import tqdm
9
+ from torch.nn.utils.rnn import pad_sequence
10
+
11
+
12
+ def build_mlp(hidden_size, projector_dim, z_dim):
13
+ return nn.Sequential(
14
+ nn.Linear(hidden_size, projector_dim),
15
+ nn.SiLU(),
16
+ nn.Linear(projector_dim, z_dim),
17
+ )
18
+
19
+
20
+ def mask_by_order(mask_len, order, bsz, seq_len):
21
+ masking = torch.zeros(bsz, seq_len, device=order.device)
22
+ masking = torch.scatter(
23
+ masking,
24
+ dim=-1,
25
+ index=order[:, : mask_len.long()],
26
+ src=torch.ones(bsz, seq_len, device=order.device),
27
+ ).bool()
28
+ return masking
29
+
30
+
31
+ class SkyworkUnipic(nn.Module):
32
+ def __init__(self, vae, vae_scale, llm, mar, tokenizer, prompt_template, siglip2):
33
+ super().__init__()
34
+ # VAE
35
+ self.vae = BUILDER.build(vae)
36
+ self.vae.requires_grad_(False)
37
+ self.vae_scale = vae_scale
38
+
39
+ # LLM
40
+ self.llm = BUILDER.build(llm)
41
+ self.tokenizer = BUILDER.build(tokenizer)
42
+ self.tokenizer.add_tokens(["<image>"], special_tokens=True)
43
+ self.image_token_idx = self.tokenizer.convert_tokens_to_ids("<image>")
44
+
45
+ self.prompt_template = prompt_template
46
+
47
+ # MAR
48
+ self.mar = BUILDER.build(mar)
49
+ # projection layers
50
+ self.proj_in = build_mlp(
51
+ hidden_size=self.mar.encoder_embed_dim,
52
+ projector_dim=self.llm.config.hidden_size,
53
+ z_dim=self.llm.config.hidden_size,
54
+ )
55
+ self.proj_out = build_mlp(
56
+ hidden_size=self.llm.config.hidden_size,
57
+ projector_dim=self.llm.config.hidden_size,
58
+ z_dim=self.mar.encoder_embed_dim,
59
+ )
60
+
61
+ # siglip
62
+ self.siglip2 = BUILDER.build(siglip2)
63
+ self.siglip2_proj = build_mlp(
64
+ hidden_size=1152,
65
+ projector_dim=self.llm.config.hidden_size,
66
+ z_dim=self.llm.config.hidden_size,
67
+ )
68
+
69
+ @property
70
+ def llm_model(self):
71
+ return self.llm.model
72
+
73
+ @property
74
+ def device(self):
75
+ return self.llm.device
76
+
77
+ @property
78
+ def dtype(self):
79
+ return self.llm.dtype
80
+
81
+ @property
82
+ def gen_seq_len(self):
83
+ return self.mar.seq_len
84
+
85
+ @property
86
+ def token_embed_dim(self):
87
+ return self.vae.embed_dim * (self.mar.patch_size**2)
88
+
89
+ @torch.no_grad()
90
+ def encode(self, x):
91
+ posterior = self.vae.encode(x)
92
+ z = posterior.mode().mul_(self.vae_scale)
93
+ z = rearrange(
94
+ z,
95
+ "b c (m p) (n q) -> b m n (c p q)",
96
+ p=self.mar.patch_size,
97
+ q=self.mar.patch_size,
98
+ )
99
+
100
+ return z
101
+
102
+ @torch.no_grad()
103
+ def decode(self, z):
104
+ z /= self.vae_scale
105
+ z = rearrange(
106
+ z,
107
+ "b m n (c p q) -> b c (m p) (n q)",
108
+ p=self.mar.patch_size,
109
+ q=self.mar.patch_size,
110
+ )
111
+
112
+ x = self.vae.decode(z)
113
+ return x
114
+
115
+ def prepare_forward_input(
116
+ self,
117
+ x,
118
+ inputs_embeds=None,
119
+ input_ids=None,
120
+ attention_mask=None,
121
+ past_key_values=None,
122
+ ):
123
+ b, l, _ = x.shape
124
+ attention_mask = attention_mask.to(device=self.device, dtype=torch.bool)
125
+ attention_mask = torch.cat(
126
+ [attention_mask, attention_mask.new_ones(b, l)], dim=1
127
+ )
128
+ position_ids = torch.cumsum(attention_mask, dim=1) - 1
129
+ position_ids[position_ids < 0] = 0
130
+
131
+ # import pdb; pdb.set_trace()
132
+
133
+ # prepare context
134
+ if past_key_values is not None:
135
+ inputs_embeds = x
136
+ position_ids = position_ids[:, -l:]
137
+ else:
138
+ if inputs_embeds is None:
139
+ input_ids = input_ids.to(self.device)
140
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
141
+ inputs_embeds = torch.cat([inputs_embeds, x], dim=1)
142
+
143
+ return dict(
144
+ inputs_embeds=inputs_embeds,
145
+ attention_mask=attention_mask,
146
+ position_ids=position_ids,
147
+ past_key_values=past_key_values,
148
+ )
149
+
150
+ def extract_visual_feature(self, x, mask=None, detach=False):
151
+ b, m, n, _ = x.shape
152
+ x = x.view(b, m * n, -1)
153
+ # x: b mn c
154
+ if mask is None:
155
+ mask = torch.zeros_like(x[..., 0])
156
+ null_embeds = self.mar.fake_latent.expand(x.shape[0], -1)
157
+ x_enc = self.mar.forward_mae_encoder(x, mask, null_embeds, image_shape=(m, n))
158
+
159
+ z_enc = self.proj_in(x_enc)
160
+ # Move buffers to the end of the image sequence
161
+ z_enc = torch.cat(
162
+ [z_enc[:, self.mar.buffer_size :], z_enc[:, : self.mar.buffer_size]], dim=1
163
+ )
164
+
165
+ if detach:
166
+ x_enc = x_enc.detach()
167
+ z_enc = z_enc.detach()
168
+
169
+ return x_enc, z_enc
170
+
171
+ def forward_mae_encoder(self, x, mask, detach=False, **context):
172
+ b, m, n, _ = x.shape
173
+ x_enc, z_enc = self.extract_visual_feature(x, mask=mask, detach=detach)
174
+ inputs = self.prepare_forward_input(x=z_enc, **context)
175
+ output = self.llm_model(**inputs, return_dict=True)
176
+
177
+ z_llm = output.last_hidden_state[:, -z_enc.shape[1] :]
178
+
179
+ # move buffers back to the start of the image sequence
180
+ z_llm = torch.cat(
181
+ [z_llm[:, -self.mar.buffer_size :], z_llm[:, : -self.mar.buffer_size]],
182
+ dim=1,
183
+ )
184
+
185
+ # residual learning
186
+ x_enc = x_enc + self.proj_out(z_llm)
187
+
188
+ return x_enc
189
+
190
+ @staticmethod
191
+ def curtail_cache(past_key_values, cur_len):
192
+ for past_key_values_ in past_key_values:
193
+ keys, values = past_key_values_
194
+ keys.data = keys.data[:, :, :cur_len]
195
+ values.data = values.data[:, :, :cur_len]
196
+
197
+ @torch.no_grad()
198
+ def prepare_text_conditions(self, prompt, cfg_prompt="Generate an image."):
199
+ all_prompts = [
200
+ self.prompt_template["INSTRUCTION"].format(input=prompt),
201
+ self.prompt_template["INSTRUCTION"].format(input=cfg_prompt),
202
+ ]
203
+
204
+ input_ids = [
205
+ self.tokenizer.encode(p, add_special_tokens=True, return_tensors="pt")[0]
206
+ for p in all_prompts
207
+ ]
208
+ valid_lens = [len(input_ids_) for input_ids_ in input_ids]
209
+ input_ids = pad_sequence(
210
+ input_ids, batch_first=True, padding_value=self.tokenizer.eos_token_id
211
+ )
212
+ attention_mask = torch.zeros_like(input_ids).bool()
213
+ for i in range(len(input_ids)):
214
+ attention_mask[i, : valid_lens[i]] = True
215
+
216
+ return dict(
217
+ input_ids=input_ids.to(self.device),
218
+ attention_mask=attention_mask.to(self.device),
219
+ )
220
+
221
+ @torch.no_grad()
222
+ def sample(
223
+ self,
224
+ input_ids=None,
225
+ inputs_embeds=None,
226
+ attention_mask=None,
227
+ num_iter=64,
228
+ cfg=1.0,
229
+ cfg_schedule="constant",
230
+ temperature=1.0,
231
+ progress=False,
232
+ mask=None,
233
+ past_key_values=None,
234
+ image_shape=None,
235
+ x_con=None,
236
+ **kwargs,
237
+ ):
238
+ if inputs_embeds is None and input_ids is not None:
239
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
240
+
241
+ bsz = attention_mask.shape[0]
242
+ if cfg != 1.0:
243
+ assert bsz % 2 == 0
244
+
245
+ if image_shape is None:
246
+ m = n = int(self.gen_seq_len**0.5)
247
+ else:
248
+ m, n = image_shape
249
+
250
+ if mask is None:
251
+ mask = torch.ones(bsz, m * n, device=self.device, dtype=self.dtype)
252
+ else:
253
+ mask = mask.view(bsz, m * n)
254
+ tokens = torch.zeros(
255
+ bsz, m * n, self.token_embed_dim, device=self.device, dtype=self.dtype
256
+ )
257
+ orders = self.mar.sample_orders(bsz, seq_len=m * n)
258
+ if cfg != 1.0:
259
+ orders[bsz // 2 :] = orders[: bsz // 2]
260
+
261
+ indices = list(range(num_iter))
262
+ if progress:
263
+ indices = tqdm(indices)
264
+
265
+ # past key values can be prepared outside (usually in multi-turn editing)
266
+ if past_key_values is None:
267
+ output = self.llm_model(
268
+ inputs_embeds=inputs_embeds,
269
+ attention_mask=None,
270
+ position_ids=None,
271
+ past_key_values=DynamicCache.from_legacy_cache(),
272
+ return_dict=True,
273
+ use_cache=True,
274
+ )
275
+ past_key_values = output.past_key_values
276
+
277
+ # generate latents
278
+ for step in indices:
279
+ cur_tokens = tokens.clone()
280
+ x_enc = self.forward_mae_encoder(
281
+ tokens.view(bsz, m, n, -1),
282
+ mask.to(self.dtype),
283
+ past_key_values=past_key_values,
284
+ # inputs_embeds=inputs_embeds,
285
+ attention_mask=attention_mask,
286
+ )
287
+ # import pdb; pdb.set_trace()
288
+ self.curtail_cache(past_key_values, inputs_embeds.shape[1])
289
+ # import pdb; pdb.set_trace()
290
+
291
+ z = self.mar.forward_mae_decoder(
292
+ x_enc, mask.to(self.dtype), image_shape=(m, n), x_con=x_con
293
+ )
294
+
295
+ # mask ratio for the next round, following MaskGIT and MAGE.
296
+ mask_ratio = np.cos(math.pi / 2.0 * (step + 1) / num_iter)
297
+ mask_len = torch.Tensor([np.floor(m * n * mask_ratio)]).to(self.device)
298
+
299
+ # masks out at least one for the next iteration
300
+ mask_len = torch.maximum(
301
+ torch.Tensor([1]).to(self.device),
302
+ torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len),
303
+ )
304
+
305
+ # get masking for next iteration and locations to be predicted in this iteration
306
+ mask_next = mask_by_order(mask_len[0], orders, bsz, m * n).to(self.device)
307
+ if cfg != 1.0:
308
+ mask_next[bsz // 2 :] = mask_next[: bsz // 2]
309
+ if step >= num_iter - 1:
310
+ mask_to_pred = mask[:bsz].bool()
311
+ else:
312
+ mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
313
+ mask = mask_next
314
+ # if not cfg == 1.0:
315
+ # mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
316
+
317
+ # sample token latents for this step
318
+ z = z[mask_to_pred.nonzero(as_tuple=True)]
319
+ # cfg schedule follow Muse
320
+ if cfg_schedule == "linear":
321
+ cfg_iter = 1 + (cfg - 1) * (m * n - mask_len[0]) / (m * n)
322
+ elif cfg_schedule == "constant":
323
+ cfg_iter = cfg
324
+ else:
325
+ raise NotImplementedError
326
+ sampled_token_latent = self.mar.diffloss.sample(
327
+ z, temperature, cfg_iter
328
+ ).to(self.dtype)
329
+ # if not cfg == 1.0:
330
+ # sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0) # Remove null class samples
331
+ # mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
332
+
333
+ cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
334
+ if cfg != 1.0:
335
+ cur_tokens[bsz // 2 :] = cur_tokens[: bsz // 2]
336
+ tokens = cur_tokens.clone()
337
+
338
+ pred = self.decode(tokens.view(bsz, m, n, -1))
339
+
340
+ if cfg != 1.0:
341
+ pred = pred[: bsz // 2]
342
+ return pred
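
Both extract_visual_feature and forward_mae_encoder rotate the MAR buffer tokens: the buffer is moved to the end of the sequence before the LLM pass and back to the front afterwards. A tiny standalone check that the two rotations are inverses; the sizes are illustrative only.

import torch

buffer_size, seq = 4, 12
z = torch.arange(seq).view(1, seq, 1)  # stand-in for a [B, buffer+tokens, C] sequence

# move the buffer tokens to the end (as in extract_visual_feature)
rotated = torch.cat([z[:, buffer_size:], z[:, :buffer_size]], dim=1)

# move them back to the front (as in forward_mae_encoder)
restored = torch.cat([rotated[:, -buffer_size:], rotated[:, :-buffer_size]], dim=1)

print(torch.equal(z, restored))  # True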
src/optimisers/constructor.py ADDED
@@ -0,0 +1,64 @@
1
+ import inspect
2
+ import torch.nn as nn
3
+ from typing import List, Optional, Union
4
+ from mmengine.optim import DefaultOptimWrapperConstructor, OptimWrapper
5
+ from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
6
+ OPTIMIZERS)
7
+
8
+
9
+ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
10
+ decay = []
11
+ no_decay = []
12
+ for name, param in model.named_parameters():
13
+ if not param.requires_grad:
14
+ continue # frozen weights
15
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or 'diffloss' in name:
16
+ no_decay.append(param) # no weight decay on bias, norm and diffloss
17
+ else:
18
+ decay.append(param)
19
+
20
+ num_decay_params = sum(p.numel() for p in decay)
21
+ num_nodecay_params = sum(p.numel() for p in no_decay)
22
+ print(f"num decayed parameter tensors: {len(decay)}, with {num_decay_params:,} parameters")
23
+ print(f"num non-decayed parameter tensors: {len(no_decay)}, with {num_nodecay_params:,} parameters")
24
+
25
+ return [
26
+ {'params': no_decay, 'weight_decay': 0.},
27
+ {'params': decay, 'weight_decay': weight_decay}]
28
+
29
+
30
+ class MAROptimWrapperConstructor(DefaultOptimWrapperConstructor):
31
+ def __call__(self, model: nn.Module) -> OptimWrapper:
32
+ if hasattr(model, 'module'):
33
+ model = model.module
34
+
35
+ optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
36
+ optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
37
+ optimizer_cfg = self.optimizer_cfg.copy()
38
+ optimizer_cls = self.optimizer_cfg['type']
39
+ # Optimizer like HybridAdam in colossalai requires the argument name
40
+ # `model_params` rather than `params`. Here we get the first argument
41
+ # name and fill it with the model parameters.
42
+ if isinstance(optimizer_cls, str):
43
+ with OPTIMIZERS.switch_scope_and_registry(None) as registry:
44
+ optimizer_cls = registry.get(self.optimizer_cfg['type'])
45
+ first_arg_name = next(
46
+ iter(inspect.signature(optimizer_cls).parameters))
47
+ # import pdb; pdb.set_trace()
48
+ param_groups = add_weight_decay(model, optimizer_cfg.pop('weight_decay', 0))
49
+ optimizer_cfg[first_arg_name] = param_groups
50
+ optimizer = OPTIMIZERS.build(optimizer_cfg)
51
+
52
+ # # if no paramwise option is specified, just use the global setting
53
+ # if not self.paramwise_cfg:
54
+ # optimizer_cfg[first_arg_name] = model.parameters()
55
+ # optimizer = OPTIMIZERS.build(optimizer_cfg)
56
+ # else:
57
+ # # set param-wise lr and weight decay recursively
58
+ # params: List = []
59
+ # self.add_params(params, model)
60
+ # optimizer_cfg[first_arg_name] = params
61
+ # optimizer = OPTIMIZERS.build(optimizer_cfg)
62
+ optim_wrapper = OPTIM_WRAPPERS.build(
63
+ optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
64
+ return optim_wrapper
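
add_weight_decay above puts biases, 1-D parameters, and anything whose name contains 'diffloss' into a zero-decay group and everything else into the decayed group. A usage sketch on a toy module, assuming the function is importable from this file; the toy architecture is illustrative only.

import torch
import torch.nn as nn

toy = nn.Sequential(nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 4))

param_groups = add_weight_decay(toy, weight_decay=0.05)
optimizer = torch.optim.AdamW(param_groups, lr=1e-4)

# group 0: 4 tensors (two biases + LayerNorm weight/bias), weight_decay=0.0
# group 1: 2 tensors (the 2-D Linear weights), weight_decay=0.05
print([len(g['params']) for g in param_groups], [g['weight_decay'] for g in param_groups])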
src/optimisers/custom_adamw.py ADDED
@@ -0,0 +1,32 @@
+ import inspect
+ from torch.optim import AdamW
+
+
+ class CustomAdamW(AdamW):
+     def __init__(self, params, weight_decay, *args, **kwargs):
+         if isinstance(params, dict):
+             params = [p for p in params.values() if p.requires_grad]
+         else:
+             params = [p for p in params if p.requires_grad]
+
+         # Create optim groups. Any parameter that is 2-D will be weight decayed,
+         # otherwise not; i.e. all weight tensors in matmuls + embeddings decay,
+         # all biases and layernorms don't.
+         decay_params = [p for p in params if p.dim() >= 2]
+         nodecay_params = [p for p in params if p.dim() < 2]
+         optim_groups = [
+             {'params': decay_params, 'weight_decay': weight_decay},
+             {'params': nodecay_params, 'weight_decay': 0.0}
+         ]
+         num_decay_params = sum(p.numel() for p in decay_params)
+         num_nodecay_params = sum(p.numel() for p in nodecay_params)
+         print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+         print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+
+         # Create the AdamW optimizer, optionally using the fused version if available.
+         # fused_available = 'fused' in inspect.signature(AdamW).parameters
+         # extra_args = dict(fused=True) if fused_available else dict()
+         # print(f"using fused AdamW: {fused_available}")
+         # kwargs.update(extra_args)
+
+         # Pass the groups positionally so extra positional args (e.g. lr) do not
+         # collide with the `params` argument of AdamW.__init__.
+         super().__init__(optim_groups, *args, **kwargs)
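A minimal usage sketch of `CustomAdamW` (the model and hyperparameters are illustrative, and the import path is an assumption based on this commit's layout): callers pass raw parameters rather than pre-built groups, since the class does the grouping itself.

import torch.nn as nn
# from src.optimisers.custom_adamw import CustomAdamW  # assumed path

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))
optimizer = CustomAdamW(model.parameters(), weight_decay=0.02, lr=1e-4, betas=(0.9, 0.95))
# The two Linear weight matrices land in the decayed group; the biases and the
# LayerNorm affine parameters land in the zero-decay group.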
src/runners/custom_runner.py ADDED
@@ -0,0 +1,202 @@
+ import copy
+ import logging
+ import inspect
+
+ from torch.utils.data import DataLoader
+ from functools import partial
+ from typing import Callable, Dict, List, Optional, Union
+
+ from mmengine.logging import print_log
+ from mmengine.dist import get_rank
+ from mmengine.dataset import worker_init_fn as default_worker_init_fn
+ from mmengine.utils import digit_version
+ from mmengine.utils.dl_utils import TORCH_VERSION
+ from mmengine.runner import FlexibleRunner
+ from mmengine.registry import (
+     DATA_SAMPLERS,
+     DATASETS,
+     FUNCTIONS,
+ )
+ from xtuner.registry import BUILDER
+
+
+ def clean_concatdataset_fields(cfg):
+     """Recursively strip invalid fields (e.g. ``image_size``) from every
+     ConcatDataset config."""
+     if isinstance(cfg, dict):
+         # If this level is a ConcatDataset config, drop the invalid fields.
+         if cfg.get('type') == "ConcatDataset":
+             for key in ['image_size']:
+                 if key in cfg:
+                     del cfg[key]
+
+         # Recurse into child fields.
+         for k, v in cfg.items():
+             clean_concatdataset_fields(v)
+
+     elif isinstance(cfg, list):
+         for item in cfg:
+             clean_concatdataset_fields(item)
+
+     return cfg
+
+
+ class CustomRunner(FlexibleRunner):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+     @staticmethod
+     def build_dataloader(
+         dataloader: Union[DataLoader, Dict],
+         seed: Optional[int] = None,
+         diff_rank_seed: bool = False,
+     ) -> DataLoader:
+         """Build dataloader.
+
+         The method builds three components:
+
+         - Dataset
+         - Sampler
+         - Dataloader
+
+         An example of ``dataloader``::
+
+             dataloader = dict(
+                 dataset=dict(type='ToyDataset'),
+                 sampler=dict(type='DefaultSampler', shuffle=True),
+                 batch_size=1,
+                 num_workers=9
+             )
+
+         Args:
+             dataloader (DataLoader or dict): A Dataloader object or a dict
+                 used to build a Dataloader object. If ``dataloader`` is a
+                 Dataloader object, it is returned as-is.
+             seed (int, optional): Random seed. Defaults to None.
+             diff_rank_seed (bool): Whether to set different seeds for
+                 different ranks. If True, the seed passed to the sampler is
+                 set to None, in order to synchronize the seeds used in
+                 samplers across different ranks. Defaults to False.
+
+         Returns:
+             Dataloader: DataLoader built from ``dataloader``.
+         """
+         if isinstance(dataloader, DataLoader):
+             return dataloader
+
+         dataloader_cfg = copy.deepcopy(dataloader)
+
+         clean_concatdataset_fields(dataloader_cfg)
+
+         # build dataset
+         dataset_cfg = dataloader_cfg.pop('dataset')
+         if isinstance(dataset_cfg, dict):
+             dataset = DATASETS.build(dataset_cfg)
+             if hasattr(dataset, 'full_init'):
+                 dataset.full_init()
+         else:
+             # fallback to raise error in dataloader
+             # if `dataset_cfg` is not a valid type
+             dataset = dataset_cfg
+
+         # build sampler
+         sampler_cfg = dataloader_cfg.pop('sampler')
+         if isinstance(sampler_cfg, dict):
+             sampler_seed = None if diff_rank_seed else seed
+             sampler = DATA_SAMPLERS.build(
+                 sampler_cfg,
+                 default_args=dict(dataset=dataset, seed=sampler_seed))
+         else:
+             # fallback to raise error in dataloader
+             # if `sampler_cfg` is not a valid type
+             sampler = sampler_cfg
+
+         # build batch sampler
+         batch_sampler_cfg = dataloader_cfg.pop('batch_sampler', None)
+         if batch_sampler_cfg is None:
+             batch_sampler = None
+         elif isinstance(batch_sampler_cfg, dict):
+             batch_sampler = DATA_SAMPLERS.build(
+                 batch_sampler_cfg,
+                 default_args=dict(
+                     dataset=dataset,
+                     sampler=sampler,
+                     batch_size=dataloader_cfg.pop('batch_size')))
+         else:
+             # fallback to raise error in dataloader
+             # if `batch_sampler_cfg` is not a valid type
+             batch_sampler = batch_sampler_cfg
+
+         # build dataloader
+         init_fn: Optional[partial]
+         if 'worker_init_fn' in dataloader_cfg:
+             worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn')
+             worker_init_fn_type = worker_init_fn_cfg.pop('type')
+             worker_init_fn = FUNCTIONS.get(worker_init_fn_type)
+             assert callable(worker_init_fn)
+             init_fn = partial(worker_init_fn,
+                               **worker_init_fn_cfg)  # type: ignore
+         else:
+             if seed is not None:
+                 disable_subprocess_warning = dataloader_cfg.pop(
+                     'disable_subprocess_warning', False)
+                 assert isinstance(disable_subprocess_warning, bool), (
+                     'disable_subprocess_warning should be a bool, but got '
+                     f'{type(disable_subprocess_warning)}')
+                 init_fn = partial(
+                     default_worker_init_fn,
+                     num_workers=dataloader_cfg.get('num_workers'),
+                     rank=get_rank(),
+                     seed=seed,
+                     disable_subprocess_warning=disable_subprocess_warning)
+             else:
+                 init_fn = None
+
+         # `persistent_workers` requires pytorch version >= 1.7
+         if ('persistent_workers' in dataloader_cfg
+                 and digit_version(TORCH_VERSION) < digit_version('1.7.0')):
+             print_log(
+                 '`persistent_workers` is only available when '
+                 'pytorch version >= 1.7',
+                 logger='current',
+                 level=logging.WARNING)
+             dataloader_cfg.pop('persistent_workers')
+
+         # The default behavior of `collate_fn` in DataLoader is to merge a
+         # list of samples into a mini-batch of Tensor(s). However, in
+         # mmengine, if `collate_fn` is not defined in dataloader_cfg,
+         # `pseudo_collate` will only convert the list of samples into a dict
+         # without stacking the batch tensor.
+         collate_fn_cfg = dataloader_cfg.pop('collate_fn',
+                                             dict(type='pseudo_collate'))
+         if isinstance(collate_fn_cfg, dict):
+             collate_fn_type = collate_fn_cfg.pop('type')
+             if isinstance(collate_fn_type, str):
+                 collate_fn = FUNCTIONS.get(collate_fn_type)
+             elif inspect.isclass(collate_fn_type):
+                 # class-type collate_fn is instantiated via the BUILDER registry
+                 collate_fn_cfg['type'] = collate_fn_type
+                 collate_fn = BUILDER.build(collate_fn_cfg)
+             else:
+                 collate_fn = collate_fn_type
+             if not inspect.isclass(collate_fn_type):
+                 collate_fn = partial(collate_fn, **collate_fn_cfg)  # type: ignore
+         elif callable(collate_fn_cfg):
+             collate_fn = collate_fn_cfg
+         else:
+             raise TypeError(
+                 'collate_fn should be a dict or callable object, but got '
+                 f'{collate_fn_cfg}')
+         data_loader = DataLoader(
+             dataset=dataset,
+             sampler=sampler if batch_sampler is None else None,
+             batch_sampler=batch_sampler,
+             collate_fn=collate_fn,
+             worker_init_fn=init_fn,
+             **dataloader_cfg)
+
+         return data_loader
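The main deviation from mmengine's stock build_dataloader is the class branch for `collate_fn`: when `type` is a class object, the collate function is instantiated through xtuner's `BUILDER` registry instead of being looked up in `FUNCTIONS`. A hedged sketch of a dataloader config that exercises that branch; `MyCollate`, `ToyDataset` and the numeric values are hypothetical stand-ins, not names from this commit.

class MyCollate:  # hypothetical stand-in for a configurable collate class
    def __init__(self, pad_index=0):
        self.pad_index = pad_index

    def __call__(self, samples):
        return {'data': samples, 'data_samples': None}

train_dataloader = dict(
    dataset=dict(type='ToyDataset'),                     # placeholder dataset config
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type=MyCollate, pad_index=151645),   # class object, not a string
    batch_size=4,
    num_workers=2,
)
# CustomRunner.build_dataloader(train_dataloader) would route this through
# BUILDER.build(...) because `type` is a class, then pass the instance to DataLoader.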