File size: 3,752 Bytes
7c8069d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
from pathlib import Path, PosixPath

import torch
from PIL import Image
from torchvision import transforms


def _list_checkpoints(output_dir):
    """Return ``(step, name)`` pairs for ``checkpoint-<step>`` subdirectories of *output_dir*, oldest first."""
    if not output_dir.is_dir():
        # Nothing has been saved yet (the directory may not even exist).
        return []
    checkpoints = []
    for entry in output_dir.iterdir():
        prefix, _, step = entry.name.partition("-")
        # Only consider directories named exactly `checkpoint-<int>`; anything else
        # (e.g. "checkpoint-best", stray log files) would otherwise crash the int()
        # sort key or be handed to rmtree.
        if prefix == "checkpoint" and step.isdecimal() and entry.is_dir():
            checkpoints.append((int(step), entry.name))
    # Numeric ordering, so checkpoint-10 sorts after checkpoint-9.
    return sorted(checkpoints)


def save_checkpoint(args, accelerator, global_step, logger):
    """Save the full training state to ``<args.output_dir>/checkpoint-<global_step>``.

    When ``args.checkpoints_total_limit`` is not ``None``, the main process first
    deletes the oldest ``checkpoint-<step>`` directories so that, once the new
    checkpoint is written, at most ``checkpoints_total_limit`` of them remain.
    Entries in ``output_dir`` that do not match ``checkpoint-<int>`` (or are not
    directories) are never counted or removed.

    Args:
        args: Namespace providing ``output_dir`` and ``checkpoints_total_limit``
            (``None`` disables pruning).
        accelerator: Object exposing ``is_main_process`` and ``save_state(path)``
            (presumably an ``accelerate.Accelerator``). Every process calls
            ``save_state``; only the main process prunes.
        global_step: Training step used to name the new checkpoint directory.
        logger: Logger used to report pruning and the save location.
    """
    output_dir = Path(args.output_dir)

    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
    if accelerator.is_main_process and args.checkpoints_total_limit is not None:
        checkpoints = [name for _, name in _list_checkpoints(output_dir)]

        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
        if len(checkpoints) >= args.checkpoints_total_limit:
            num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
            removing_checkpoints = checkpoints[0:num_to_remove]

            logger.info(
                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
            )
            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

            for removing_checkpoint in removing_checkpoints:
                shutil.rmtree(output_dir / removing_checkpoint)

    save_path = output_dir / f"checkpoint-{global_step}"
    accelerator.save_state(save_path)
    logger.info(f"Saved state to {save_path}")


def load_images_to_tensor(path, target_size=(1024, 1024)):
    """Load one image, a list of images, or every image in a directory as one batch.

    Args:
        path: One of
            - a ``list`` of image file paths, loaded in the given order;
            - a directory (``str`` or ``os.PathLike``): every entry directly inside
              it whose name ends with a supported extension (case-insensitive) is
              loaded, in sorted filename order so the batch order is deterministic;
            - a single image file path (``str`` or ``os.PathLike``).
        target_size: ``(height, width)`` that every image is resized to.

    Returns:
        torch.Tensor: ``[B, 3, H, W]`` with values in ``[0, 1]`` (``ToTensor`` applied
        to the RGB-converted image).

    Raises:
        ValueError: if ``path`` has an unsupported type, no candidate image files
            are found, or none of the candidates could be loaded.

    Files that fail to open or transform are reported via ``print`` and skipped.
    """
    valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp')

    # Accept pathlib.Path and other path-like objects in addition to plain strings.
    if isinstance(path, os.PathLike):
        path = os.fspath(path)

    # Resolve the candidate file paths once, up front.
    if isinstance(path, list):
        image_files = path
    elif isinstance(path, str) and os.path.isdir(path):
        # os.listdir order is arbitrary; sort so batch order is reproducible.
        image_files = sorted(
            os.path.join(path, name)
            for name in os.listdir(path)
            if name.lower().endswith(valid_extensions)
        )
    elif isinstance(path, str):
        image_files = [path]
    else:
        raise ValueError(f"Unsupported path type: {type(path)}")

    if not image_files:
        raise ValueError(f"No valid images found in {path}")

    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])

    tensors = []
    for img_path in image_files:
        try:
            # `with` releases the file handle PIL keeps open after the lazy open;
            # convert() returns a new, fully loaded image.
            with Image.open(img_path) as img:
                tensors.append(transform(img.convert('RGB')))
        except Exception as e:
            # Best effort: skip unreadable/corrupt files rather than failing the batch.
            print(f"Error processing {img_path}: {e}")

    if not tensors:
        raise ValueError("No images could be loaded")

    batch_tensor = torch.stack(tensors)

    # Internal sanity check on the resize/convert pipeline (not input validation).
    assert batch_tensor.shape[1:] == (3, *target_size), \
        f"Output shape is {batch_tensor.shape}, expected (B, 3, {target_size[0]}, {target_size[1]})"

    return batch_tensor