Spaces:
Build error
Build error
File size: 3,646 Bytes
b6af722 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import IO, Optional, Tuple, Union
import numpy as np
from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler
try:
from PIL import Image
except ImportError:
Image = None
class PILHandler(BaseFileHandler):
    """File handler that decodes and encodes images through Pillow.

    Loading returns the image as a PIL image, a numpy array or a torch
    tensor; dumping calls ``obj.save`` with a Pillow format name derived
    from ``format``.
    """

    # Image format name / extension used when saving (e.g. "jpg", "png").
    # No value is assigned here — presumably set by a subclass or on the
    # instance before dumping; verify against the code that registers handlers.
    format: str
    # NOTE(review): looks like a flag telling the base handler this format is
    # binary rather than text — confirm against BaseFileHandler.
    str_like = False

    def load_from_fileobj(
        self,
        file: IO[bytes],
        fmt: str = "pil",
        size: Optional[Union[int, Tuple[int, int]]] = None,
        **kwargs,
    ):
        """Load an image from a file-like object and return it in a given format.

        Args:
            file (IO[bytes]): A file-like object containing the image data.
            fmt (str): The output representation. 'numpy', 'np' and 'npy'
                return a numpy array, 'pil' returns the PIL image, 'th' and
                'torch' return a torch tensor.
            size (Optional[Union[int, Tuple[int, int]]]): Target size as a
                single integer (used for both dimensions) or a
                (width, height) tuple. If given, the image is resized with
                the LANCZOS filter before conversion.
            **kwargs: Extra keyword arguments forwarded to ``np.array`` for
                the numpy and torch formats.

        Returns:
            The image in the representation selected by ``fmt``. Torch
            tensors built from 3-dimensional arrays are permuted from
            HxWxC to CxHxW.

        Raises:
            IOError: If the image cannot be loaded or processed. Every
                failure inside this method is re-raised as IOError with the
                original exception chained as its cause, including an
                unsupported ``fmt`` (ValueError) and a missing Pillow
                installation (ImportError).
        """
        try:
            if Image is None:
                # The module-level import fell back to None; say so explicitly
                # instead of failing with an opaque AttributeError on None.
                raise ImportError("Pillow (PIL) is required to load images but is not installed.")

            img = Image.open(file)
            img.load()  # Explicitly load the image data

            if size is not None:
                if isinstance(size, int):
                    # A single integer means a square target size.
                    size = (size, size)
                # Image.ANTIALIAS was removed in Pillow 10; it was an alias of
                # LANCZOS. Prefer the Resampling enum when this Pillow has it.
                resampling = getattr(Image, "Resampling", Image)
                img = img.resize(size, resampling.LANCZOS)

            # Return the image in the requested format
            if fmt in ["numpy", "np", "npy"]:
                return np.array(img, **kwargs)
            if fmt == "pil":
                return img
            if fmt in ["th", "torch"]:
                # Imported lazily so torch is only needed for this format.
                import torch

                # Convert to tensor
                img_tensor = torch.from_numpy(np.array(img, **kwargs))
                # Convert image from HxWxC to CxHxW
                if img_tensor.ndim == 3:
                    img_tensor = img_tensor.permute(2, 0, 1)
                return img_tensor

            raise ValueError(
                "Unsupported format. Supported formats are 'numpy', 'np', 'npy', 'pil', 'th', and 'torch'."
            )
        except Exception as e:
            raise IOError(f"Unable to load image: {e}") from e

    def dump_to_fileobj(self, obj, file: IO[bytes], **kwargs):
        """Save ``obj`` to ``file`` by calling ``obj.save``.

        Args:
            obj: Object exposing ``save(file, **kwargs)`` — presumably a PIL
                image; confirm against callers.
            file (IO[bytes]): Writable binary file-like object.
            **kwargs: Forwarded to ``obj.save``. A ``format`` entry overrides
                the handler's own ``format``; otherwise ``self.format`` is
                used. In both cases 'jpg' is mapped to 'JPEG' and any other
                name is upper-cased.
        """
        # Honour an explicit caller-supplied format; fall back to the
        # handler's own format when none (or an empty one) is given.
        requested = kwargs.get("format") or self.format
        kwargs["format"] = "JPEG" if requested.lower() == "jpg" else requested.upper()
        obj.save(file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        """Not supported by this handler; always raises NotImplementedError."""
        raise NotImplementedError
|