# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A CLI to run ImageTokenizer on plain images based on torch.jit.

Usage:
    python3 -m cosmos_predict1.tokenizer.inference.image_cli \
        --image_pattern 'path/to/input/folder/*.jpg' \
        --output_dir ./reconstructions \
        --checkpoint_enc ./checkpoints//encoder.jit \
        --checkpoint_dec ./checkpoints//decoder.jit

    Optionally, you can run the model in pure PyTorch mode:

    python3 -m cosmos_predict1.tokenizer.inference.image_cli \
        --image_pattern 'path/to/input/folder/*.jpg' \
        --mode torch \
        --tokenizer_type CI8x8-360p \
        --checkpoint_enc ./checkpoints//encoder.jit \
        --checkpoint_dec ./checkpoints//decoder.jit
"""

import os
import sys
from argparse import ArgumentParser, Namespace

import numpy as np
from loguru import logger as logging

from cosmos_predict1.tokenizer.inference.image_lib import ImageTokenizer
from cosmos_predict1.tokenizer.inference.utils import (
    get_filepaths,
    get_output_filepath,
    read_image,
    resize_image,
    write_image,
)
from cosmos_predict1.tokenizer.networks import TokenizerConfigs


def _parse_args() -> Namespace:
    parser = ArgumentParser(description="A CLI for running ImageTokenizer on plain images.")
    parser.add_argument(
        "--image_pattern",
        type=str,
        default="path/to/images/*.jpg",
        help="Glob pattern for input images.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="JIT full Autoencoder model filepath.",
    )
    parser.add_argument(
        "--checkpoint_enc",
        type=str,
        default=None,
        help="JIT Encoder model filepath.",
    )
    parser.add_argument(
        "--checkpoint_dec",
        type=str,
        default=None,
        help="JIT Decoder model filepath.",
    )
    parser.add_argument(
        "--tokenizer_type",
        type=str,
        default=None,
        choices=[
            "CI8x8-360p",
            "CI16x16-360p",
            "DI8x8-360p",
            "DI16x16-360p",
        ],
        help="Specifies the tokenizer type.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["torch", "jit"],
        default="jit",
        help="Specify the backend: native 'torch' or 'jit' (default: 'jit').",
    )
    parser.add_argument(
        "--short_size",
        type=int,
        default=None,
        help="Size to resize the short side of inputs to. None (no resizing) by default.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="Sets the precision. Default bfloat16.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device for invoking the model.",
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Output directory.")
    parser.add_argument(
        "--save_input",
        action="store_true",
        help="If set, the input image is saved alongside the reconstruction.",
    )
    args = parser.parse_args()
    return args


logging.info("Initializing args ...")
args = _parse_args()
if args.mode == "torch" and args.tokenizer_type is None:
    logging.error("The 'torch' backend requires --tokenizer_type to be specified.")
    sys.exit(1)


def _run_eval() -> None:
    """Invokes the evaluation pipeline."""
    if args.checkpoint_enc is None and args.checkpoint_dec is None and args.checkpoint is None:
        logging.warning(
            "Aborting. Provide either both encoder and decoder JIT checkpoints, or the full autoencoder JIT checkpoint."
        )
        return

    # The pure-PyTorch backend needs the tokenizer config; the JIT backend does not.
    if args.mode == "torch":
        _type = args.tokenizer_type.replace("-", "_")
        _config = TokenizerConfigs[_type].value
    else:
        _config = None

    logging.info(
        f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..."
    )
    autoencoder = ImageTokenizer(
        checkpoint=args.checkpoint,
        checkpoint_enc=args.checkpoint_enc,
        checkpoint_dec=args.checkpoint_dec,
        tokenizer_config=_config,
        device=args.device,
        dtype=args.dtype,
    )

    filepaths = get_filepaths(args.image_pattern)
    logging.info(f"Found {len(filepaths)} images from {args.image_pattern}.")

    for filepath in filepaths:
        logging.info(f"Reading image {filepath} ...")
        image = read_image(filepath)
        image = resize_image(image, short_size=args.short_size)
        # Add a batch dimension before invoking the tokenizer.
        batch_image = np.expand_dims(image, axis=0)

        logging.info("Invoking the autoencoder model ...")
        output_image = autoencoder(batch_image)[0]

        output_filepath = get_output_filepath(filepath, output_dir=args.output_dir)
        logging.info(f"Outputting {output_filepath} ...")
        write_image(output_filepath, output_image)
        if args.save_input:
            # Optionally save the (resized) input next to the reconstruction, e.g. `name_input.jpg`.
            ext = os.path.splitext(output_filepath)[-1]
            input_filepath = output_filepath.replace(ext, "_input" + ext)
            write_image(input_filepath, image)


@logging.catch(reraise=True)
def main() -> None:
    _run_eval()


if __name__ == "__main__":
    main()