# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

import argparse
import importlib.util
import os

import torch
from peft import PeftModel
from pydantic import BaseModel
from termcolor import colored

import llava
from llava import conversation as clib

# Image and Video are currently unused: this CLI only handles audio media.
from llava.media import Image, Sound, Video
from llava.model.configuration_llava import JsonSchemaResponseFormat, ResponseFormat


def get_schema_from_python_path(path: str) -> str:
    # Dynamically import the user-provided schema module from its file path.
    schema_path = os.path.abspath(path)
    spec = importlib.util.spec_from_file_location("schema_module", schema_path)
    schema_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(schema_module)

    # Get the Main class from the loaded module.
    Main = schema_module.Main
    assert issubclass(
        Main, BaseModel
    ), f"The provided python file {path} does not contain a class Main that describes a JSON schema"
    return Main.schema_json()


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-base", "-mb", type=str, required=True)
    parser.add_argument("--model-path", "-mp", type=str, required=True)
    parser.add_argument("--conv-mode", "-c", type=str, default="auto")
    parser.add_argument("--text", type=str)
    parser.add_argument("--media", type=str, nargs="+")
    parser.add_argument("--json-mode", action="store_true")
    parser.add_argument("--peft-mode", action="store_true")
    parser.add_argument("--json-schema", type=str, default=None)
    args = parser.parse_args()

    # Convert the JSON-mode flags into a response format.
    if not args.json_mode:
        response_format = None
    elif args.json_schema is None:
        response_format = ResponseFormat(type="json_object")
    else:
        schema_str = get_schema_from_python_path(args.json_schema)
        print(schema_str)
        response_format = ResponseFormat(type="json_schema", json_schema=JsonSchemaResponseFormat(schema=schema_str))

    # Load the base model, optionally applying a PEFT adapter on top of it.
    model = llava.load(args.model_base)
    if args.peft_mode:
        model = PeftModel.from_pretrained(
            model,
            args.model_path,
            device_map="auto",
            torch_dtype=torch.float16,
        )

    # Set conversation mode.
    clib.default_conversation = clib.conv_templates[args.conv_mode].copy()

    # Prepare the multi-modal prompt. Only audio files are accepted;
    # anything else is rejected explicitly.
    prompt = []
    for media in args.media or []:
        if any(media.endswith(ext) for ext in [".wav", ".mp3", ".flac"]):
            media = Sound(media)
        else:
            raise ValueError(f"Unsupported media type: {media}")
        prompt.append(media)
    if args.text is not None:
        prompt.append(args.text)

    # Generate and print the response.
    response = model.generate_content(prompt, response_format=response_format)
    print(colored(response, "cyan", attrs=["bold"]))


if __name__ == "__main__":
    main()
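
# ---------------------------------------------------------------------------
# Example usage (a sketch; the script name `infer.py`, the checkpoint id, and
# the file paths below are hypothetical placeholders, not part of this repo):
#
#   python infer.py \
#       --model-base nvidia/some-base-checkpoint \
#       --model-path ./peft_adapter \
#       --peft-mode \
#       --media clip.wav \
#       --text "Describe this audio." \
#       --json-mode --json-schema schema.py
#
# With --json-schema, the referenced file must define a pydantic `Main` class,
# which get_schema_from_python_path loads and serializes to a JSON schema.
# A minimal schema.py (illustrative field names) could look like:
#
#   from pydantic import BaseModel
#
#   class Main(BaseModel):
#       caption: str
#       events: list[str]
# ---------------------------------------------------------------------------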