# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
import argparse
import csv
import itertools
import json
import os
import torch
from datasets import load_dataset
from tqdm import tqdm
import llava
from llava import conversation as conversation_lib
from llava.data.builder import DATASETS
from llava.eval.mmmu_utils.eval_utils import parse_choice
from llava.utils import distributed as dist
from llava.utils import io
from llava.utils.logging import logger
def load_existing_ids(output_file):
    """Load previously saved outputs so an interrupted run can resume.

    Args:
        output_file: Path to a JSONL file where each line is a JSON object
            containing at least an "id" key.

    Returns:
        Tuple of (processed_ids, outputs): the set of ids already present in
        the file and the list of previously saved records. Both are empty when
        the file is missing or unreadable — resume is best-effort by design,
        so errors are reported, not raised.
    """
    if not os.path.exists(output_file):
        return set(), []
    try:
        with open(output_file, "r") as f:
            # Skip blank lines (e.g. a trailing newline left by an interrupted
            # run) — previously a single empty line made json.loads raise and
            # the whole resume state was silently discarded.
            outputs = [json.loads(line) for line in f if line.strip()]
        processed_ids = {item["id"] for item in outputs}
        return processed_ids, outputs
    except (OSError, json.JSONDecodeError, KeyError) as e:
        # Corrupt or partial resume file: start from scratch rather than crash.
        print(f"Error loading existing outputs: {e}")
        return set(), []
def main() -> None:
    """Run audio-QA inference for one dataset task, sharded across ranks, and save JSONL outputs."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default=None)
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--task", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default="auto")
    # Parsed directly from a JSON string, e.g. '{"max_new_tokens": 256}'.
    parser.add_argument("--generation-config", type=json.loads)
    parser.add_argument("--output-dir", type=str, default=None)
    args = parser.parse_args()
    # Set up distributed environment
    dist.init()
    # Each local rank claims a strided subset of the visible CUDA devices.
    devices = range(dist.local_rank(), torch.cuda.device_count(), dist.local_size())
    torch.cuda.set_device(devices[0])
    # Load stage 3 model with line 56
    # NOTE(review): --model-base is passed as the model *path* while
    # model_base=None; --model-path is only used by the commented-out PEFT
    # branch below — confirm the flag naming matches the intended usage.
    model = llava.load(args.model_base, model_base=None, devices=devices)
    # Uncomment line 58-63 to load stage 3.5 model on top of stage 3 for thinking mode and long audio mode
    # model = PeftModel.from_pretrained(
    #     model,
    #     args.model_path,
    #     device_map="auto",
    #     torch_dtype=torch.float16,
    # )
    # Set up generation config
    generation_config = model.default_generation_config
    if args.generation_config is not None:
        generation_config.update(**args.generation_config)
    # Load data and chunk it
    json_file = DATASETS[args.task]["data_path"]
    instances = io.load(json_file)
    # Round-robin shard: each rank processes every size()-th instance.
    instances = instances[dist.rank() :: dist.size()]
    output_path = os.path.join(args.output_dir, f"outputs_{args.task}.jsonl")
    # Resume support: skip instances whose sound path was already written.
    processed_ids, outputs = load_existing_ids(output_path)
    count = len(outputs)
    # Run inference
    new_outputs = []
    for instance in tqdm(instances, disable=not dist.is_main()):
        # NOTE(review): uuid is read but never used; outputs are keyed by
        # sound_path instead ("id" below) — confirm that is intentional.
        uuid = instance["id"]
        sound_path = instance["sound"]
        if sound_path in processed_ids:
            continue  # Skip if already processed
        sound = llava.Sound(sound_path)
        conversations = instance["conversations"]
        # First turn is the question; second turn is the ground-truth answer.
        question = conversations[0]["value"]
        response = model.generate_content([sound, question], generation_config=generation_config)
        print("response", response)
        output = {"id": sound_path, "question": question, "gt_answer": conversations[1]["value"], "pred": response}
        new_outputs.append(output)
        count = count +1
        # Periodic checkpoint so long runs can resume after interruption.
        if count % 20 == 0:
            # Gather and save outputs
            # NOTE(review): dist.gather is a collective call, but this branch
            # depends on each rank's own count (resume length + progress), so
            # ranks may reach it a different number of times — verify this
            # cannot deadlock in multi-rank runs.
            if dist.size() > 1:
                outputs_new = dist.gather(new_outputs, dst=0)
                if dist.is_main():
                    outputs_new = list(itertools.chain(*outputs_new))
                    final_outputs = outputs + outputs_new
                    io.save(os.path.join(args.output_dir, f"outputs_{args.task}.jsonl"), final_outputs)
            else:
                final_outputs = outputs + new_outputs
                io.save(os.path.join(args.output_dir, f"outputs_{args.task}.jsonl"), final_outputs)
    # Final gather: rank 0 flattens all ranks' new outputs and writes the file;
    # non-main ranks exit without saving.
    if dist.size() > 1:
        new_outputs = dist.gather(new_outputs, dst=0)
        if not dist.is_main():
            return
        new_outputs = list(itertools.chain(*new_outputs))
    final_outputs = outputs + new_outputs
    io.save(os.path.join(args.output_dir, "outputs_"+str(args.task)+".jsonl"), final_outputs)
# Entry-point guard: run inference only when executed as a script, not on import.
if __name__ == "__main__":
    main()