# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

import argparse
import importlib.util
import os

import torch
from peft import PeftModel
from pydantic import BaseModel
from termcolor import colored

import llava
from llava import conversation as clib
from llava.media import Sound
from llava.model.configuration_llava import JsonSchemaResponseFormat, ResponseFormat


def get_schema_from_python_path(path: str) -> str:
    """Load a Python file and return the JSON schema of its `Main` pydantic model as a string."""
    schema_path = os.path.abspath(path)
    spec = importlib.util.spec_from_file_location("schema_module", schema_path)
    schema_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(schema_module)

    # Get the Main class from the loaded module and check that it is a pydantic model.
    Main = schema_module.Main
    assert issubclass(
        Main, BaseModel
    ), f"The provided python file {path} does not contain a class Main that describes a JSON schema"
    return Main.schema_json()
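

# A minimal sketch of a schema file consumable via --json-schema (the file name
# and fields below are illustrative, not part of this repo): the module must
# define a pydantic model named `Main`, e.g.
#
#     # schema.py
#     from pydantic import BaseModel
#
#     class Main(BaseModel):
#         caption: str
#         tags: list[str]
#
# get_schema_from_python_path("schema.py") then returns Main's JSON schema string.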


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-base", "-mb", type=str, required=True)
    parser.add_argument("--model-path", "-mp", type=str, required=True)
    parser.add_argument("--conv-mode", "-c", type=str, default="auto")
    parser.add_argument("--text", type=str)
    parser.add_argument("--media", type=str, nargs="+")
    parser.add_argument("--json-mode", action="store_true")
    parser.add_argument("--peft-mode", action="store_true")
    parser.add_argument("--json-schema", type=str, default=None)
    args = parser.parse_args()
    # Convert json mode to response format
    if not args.json_mode:
        response_format = None
    elif args.json_schema is None:
        response_format = ResponseFormat(type="json_object")
    else:
        schema_str = get_schema_from_python_path(args.json_schema)
        print(schema_str)
        response_format = ResponseFormat(type="json_schema", json_schema=JsonSchemaResponseFormat(schema=schema_str))
    # Load model
    model = llava.load(args.model_base)
    if args.peft_mode:
        model = PeftModel.from_pretrained(
            model,
            args.model_path,
            device_map="auto",
            torch_dtype=torch.float16,
        )
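    # Note: in --peft-mode, --model-path is expected to point at PEFT adapter
    # weights (e.g. a LoRA checkpoint from fine-tuning); PeftModel wraps the
    # base model and applies the adapter at inference time without modifying
    # the base weights.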
    # Set conversation mode
    clib.default_conversation = clib.conv_templates[args.conv_mode].copy()
    # Prepare multi-modal prompt; only audio media are supported by this
    # audio-language CLI, and media objects precede the text in the prompt.
    prompt = []
    for media in args.media or []:
        if any(media.endswith(ext) for ext in [".wav", ".mp3", ".flac"]):
            media = Sound(media)
        else:
            raise ValueError(f"Unsupported media type: {media}")
        prompt.append(media)
    if args.text is not None:
        prompt.append(args.text)
    # Generate response
    response = model.generate_content(prompt, response_format=response_format)
    print(colored(response, "cyan", attrs=["bold"]))


if __name__ == "__main__":
    main()
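
# Example invocation (a sketch; the script name, model identifier, adapter
# path, and audio file below are placeholders, not shipped with this repo):
#
#   python infer.py \
#       --model-base <base-model-id-or-path> \
#       --model-path ./checkpoints/peft_adapter --peft-mode \
#       --text "Describe this sound." \
#       --media example.wav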