# audio-flamingo-3 / llava / cli / infer_audio.py
# (upload metadata: SreyanG-NVIDIA — "Upload 225 files" — 174ae06 verified)
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
import argparse
import importlib.util
import json
import os
from pydantic import BaseModel
from termcolor import colored
import llava
from llava import conversation as clib
from llava.media import Image, Video, Sound
from llava.model.configuration_llava import JsonSchemaResponseFormat, ResponseFormat
from peft import PeftModel
import torch
def get_schema_from_python_path(path: str) -> str:
    """Load a JSON schema from a Python file that defines a pydantic ``Main`` class.

    Args:
        path: Path to a Python source file containing a top-level class named
            ``Main`` that subclasses ``pydantic.BaseModel``.

    Returns:
        The JSON schema of ``Main``, serialized as a string.

    Raises:
        ImportError: If the file cannot be loaded as a Python module.
        TypeError: If the module has no ``Main`` class or it is not a
            ``BaseModel`` subclass.
    """
    schema_path = os.path.abspath(path)
    spec = importlib.util.spec_from_file_location("schema_module", schema_path)
    # spec / spec.loader can be None for an unloadable path — fail with a clear error.
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not load a Python module from {path}")
    schema_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(schema_module)
    # Get the Main class from the loaded module.
    Main = getattr(schema_module, "Main", None)
    # Use a real exception instead of `assert`: asserts are stripped under -O,
    # and a missing attribute would otherwise surface as an opaque AttributeError.
    if not (isinstance(Main, type) and issubclass(Main, BaseModel)):
        raise TypeError(
            f"The provided python file {path} does not contain a class Main that describes a JSON schema"
        )
    return Main.schema_json()
def main() -> None:
    """CLI entry point: run audio inference with an Audio-Flamingo / LLaVA model.

    Parses command-line arguments, loads the base model (optionally applying a
    PEFT adapter in fp16), builds a multi-modal prompt from audio files and an
    optional text query, and prints the generated response.

    Raises:
        ValueError: If a ``--media`` path does not have a supported audio
            extension (``.wav``, ``.mp3``, ``.flac``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-base", "-mb", type=str, required=True)
    parser.add_argument("--model-path", "-mp", type=str, required=True)
    parser.add_argument("--conv-mode", "-c", type=str, default="auto")
    parser.add_argument("--text", type=str)
    parser.add_argument("--media", type=str, nargs="+")
    parser.add_argument("--json-mode", action="store_true")
    parser.add_argument("--peft-mode", action="store_true")
    parser.add_argument("--json-schema", type=str, default=None)
    args = parser.parse_args()

    # Convert json mode flags to a response format understood by the model.
    if not args.json_mode:
        response_format = None
    elif args.json_schema is None:
        response_format = ResponseFormat(type="json_object")
    else:
        schema_str = get_schema_from_python_path(args.json_schema)
        print(schema_str)
        response_format = ResponseFormat(type="json_schema", json_schema=JsonSchemaResponseFormat(schema=schema_str))

    # Load the base model; optionally wrap it with a PEFT adapter from --model-path.
    model = llava.load(args.model_base)
    if args.peft_mode:
        model = PeftModel.from_pretrained(
            model,
            args.model_path,
            device_map="auto",
            torch_dtype=torch.float16,
        )

    # Set conversation mode (template must exist in clib.conv_templates).
    clib.default_conversation = clib.conv_templates[args.conv_mode].copy()

    # Prepare multi-modal prompt: audio files first, then the text query.
    prompt = []
    for media in args.media or []:
        # Accept common audio extensions, case-insensitively (.WAV works too).
        if media.lower().endswith((".wav", ".mp3", ".flac")):
            prompt.append(Sound(media))
        else:
            raise ValueError(f"Unsupported media type: {media}")
    if args.text is not None:
        prompt.append(args.text)

    # Generate and print the response.
    response = model.generate_content(prompt, response_format=response_format)
    print(colored(response, "cyan", attrs=["bold"]))
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()