# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A CLI to run ImageTokenizer on plain images based on torch.jit.
Usage:
python3 -m cosmos_predict1.tokenizer.inference.image_cli \
--image_pattern 'path/to/input/folder/*.jpg' \
--output_dir ./reconstructions \
--checkpoint_enc ./checkpoints/<model-name>/encoder.jit \
--checkpoint_dec ./checkpoints/<model-name>/decoder.jit
Optionally, you can run the model in pure PyTorch mode:
python3 -m cosmos_predict1.tokenizer.inference.image_cli \
--image_pattern 'path/to/input/folder/*.jpg' \
--mode torch \
--tokenizer_type CI8x8 \
--checkpoint_enc ./checkpoints/<model-name>/encoder.jit \
--checkpoint_dec ./checkpoints/<model-name>/decoder.jit
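
The tokenizer can also be called programmatically. A minimal sketch that mirrors
the calls this script makes (the checkpoint and image paths are placeholders):

    import numpy as np
    from cosmos_predict1.tokenizer.inference.image_lib import ImageTokenizer
    from cosmos_predict1.tokenizer.inference.utils import read_image

    autoencoder = ImageTokenizer(
        checkpoint=None,
        checkpoint_enc="./checkpoints/<model-name>/encoder.jit",
        checkpoint_dec="./checkpoints/<model-name>/decoder.jit",
        tokenizer_config=None,
        device="cuda",
        dtype="bfloat16",
    )
    image = read_image("path/to/image.jpg")      # HxWxC image array
    batch_image = np.expand_dims(image, axis=0)  # add batch dim: 1xHxWxC
    output_image = autoencoder(batch_image)[0]   # reconstructed image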
"""
import os
import sys
from argparse import ArgumentParser, Namespace

import numpy as np
from loguru import logger as logging

from cosmos_predict1.tokenizer.inference.image_lib import ImageTokenizer
from cosmos_predict1.tokenizer.inference.utils import (
    get_filepaths,
    get_output_filepath,
    read_image,
    resize_image,
    write_image,
)
from cosmos_predict1.tokenizer.networks import TokenizerConfigs


def _parse_args() -> Namespace:
parser = ArgumentParser(description="A CLI for running ImageTokenizer on plain images.")
parser.add_argument(
"--image_pattern",
type=str,
default="path/to/images/*.jpg",
help="Glob pattern.",
)
parser.add_argument(
"--checkpoint",
type=str,
default=None,
help="JIT full Autoencoder model filepath.",
)
parser.add_argument(
"--checkpoint_enc",
type=str,
default=None,
help="JIT Encoder model filepath.",
)
parser.add_argument(
"--checkpoint_dec",
type=str,
default=None,
help="JIT Decoder model filepath.",
)
parser.add_argument(
"--tokenizer_type",
type=str,
default=None,
choices=[
"CI8x8-360p",
"CI16x16-360p",
"DI8x8-360p",
"DI16x16-360p",
],
help="Specifies the tokenizer type.",
)
parser.add_argument(
"--mode",
type=str,
choices=["torch", "jit"],
default="jit",
help="Specify the backend: native 'torch' or 'jit' (default: 'jit')",
)
parser.add_argument(
"--short_size",
type=int,
default=None,
help="The size to resample inputs. None, by default.",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
help="Sets the precision. Default bfloat16.",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device for invoking the model.",
)
parser.add_argument("--output_dir", type=str, default=None, help="Output directory.")
parser.add_argument(
"--save_input",
action="store_true",
help="If on, the input image will be be outputed too.",
)
args = parser.parse_args()
return args
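# Parse CLI arguments at import time; `args` is shared as a module-level
# global with `_run_eval` below.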
logging.info("Initializes args ...")
args = _parse_args()
if args.mode == "torch" and args.tokenizer_type is None:
logging.error("'torch' backend requires the tokenizer_type to be specified.")
sys.exit(1)
def _run_eval() -> None:
"""Invokes the evaluation pipeline."""
    if args.checkpoint_enc is None and args.checkpoint_dec is None and args.checkpoint is None:
        logging.warning(
            "Aborting. Provide the full autoencoder JIT model via --checkpoint, "
            "or an encoder/decoder pair via --checkpoint_enc and --checkpoint_dec."
        )
        return
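    # The "torch" backend needs a network config; look it up by tokenizer type.
    # TokenizerConfigs keys use underscores, e.g. "CI8x8_360p" for CI8x8-360p.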
if args.mode == "torch":
_type = args.tokenizer_type.replace("-", "_")
_config = TokenizerConfigs[_type].value
else:
_config = None
    logging.info(
        "Loading torch.jit model(s) from "
        f"`{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..."
    )
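    # Build the tokenizer from the full autoencoder checkpoint, or from an
    # encoder/decoder pair.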
autoencoder = ImageTokenizer(
checkpoint=args.checkpoint,
checkpoint_enc=args.checkpoint_enc,
checkpoint_dec=args.checkpoint_dec,
tokenizer_config=_config,
device=args.device,
dtype=args.dtype,
)
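    # Gather all input images matching the glob pattern.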
filepaths = get_filepaths(args.image_pattern)
logging.info(f"Found {len(filepaths)} images from {args.image_pattern}.")
for filepath in filepaths:
logging.info(f"Reading image {filepath} ...")
image = read_image(filepath)
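        # Resize so the short side matches --short_size (left unchanged when None).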
image = resize_image(image, short_size=args.short_size)
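        # Add a leading batch dimension: HxWxC -> 1xHxWxC.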
batch_image = np.expand_dims(image, axis=0)
logging.info("Invoking the autoencoder model in ... ")
output_image = autoencoder(batch_image)[0]
output_filepath = get_output_filepath(filepath, output_dir=args.output_dir)
logging.info(f"Outputing {output_filepath} ...")
write_image(output_filepath, output_image)
        if args.save_input:
            # Insert "_input" before the extension. Avoid str.replace(ext, ...),
            # which could also match the extension substring earlier in the path.
            root, ext = os.path.splitext(output_filepath)
            input_filepath = root + "_input" + ext
            write_image(input_filepath, image)
@logging.catch(reraise=True)
def main() -> None:
    _run_eval()


if __name__ == "__main__":
    main()