# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
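
# Runs distributed inference for an audio-language model over a task registered in
# llava.data.builder.DATASETS and writes predictions to outputs_<task>.jsonl in
# --output-dir, resuming from any partial output file found there.
#
# Illustrative invocation (the launcher, script name, checkpoint path, and task name
# below are assumptions for this sketch, not values taken from this repository):
#   torchrun --nproc-per-node=8 eval_audio.py \
#       --model-base /path/to/stage3_checkpoint \
#       --task audio_eval_task \
#       --output-dir ./results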

import argparse
import itertools
import json
import os

import torch
from tqdm import tqdm

import llava
from llava.data.builder import DATASETS
from llava.utils import distributed as dist
from llava.utils import io


def load_existing_ids(output_file):
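    """Return (processed_ids, prior_outputs) read from an existing JSONL output file.

    A missing or unreadable file yields an empty set and list, so evaluation starts fresh.
    """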
    if not os.path.exists(output_file):
        return set(), []
    try:
        with open(output_file, "r") as f:
            lines = f.readlines()
            outputs = [json.loads(line) for line in lines]
            processed_ids = {item["id"] for item in outputs}
            return processed_ids, outputs
    except Exception as e:
        print(f"Error loading existing outputs: {e}")
        return set(), []


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default=None)
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--task", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default="auto")
    parser.add_argument("--generation-config", type=json.loads)
    parser.add_argument("--output-dir", type=str, default=None)
    args = parser.parse_args()

    # Set up distributed environment
    dist.init()
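    # Each local process takes a strided slice of the visible GPUs (starting at its
    # local rank) and binds to the first GPU in that slice.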
    devices = range(dist.local_rank(), torch.cuda.device_count(), dist.local_size())
    torch.cuda.set_device(devices[0])

    # Load the stage 3 model (passed via --model-base)
    model = llava.load(args.model_base, model_base=None, devices=devices)
    # Uncomment the block below (and add `from peft import PeftModel`) to load the
    # stage 3.5 adapter from --model-path on top of the stage 3 model for thinking
    # mode and long audio mode.
    # model = PeftModel.from_pretrained(
    #     model,
    #     args.model_path,
    #     device_map="auto",
    #     torch_dtype=torch.float16,
    # )
    # Set up generation config
    generation_config = model.default_generation_config
    if args.generation_config is not None:
        generation_config.update(**args.generation_config)

    # Load data and chunk it
    json_file = DATASETS[args.task]["data_path"]
    instances = io.load(json_file)
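    # Round-robin shard across ranks: rank r keeps instances r, r + size, r + 2*size, ...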
    instances = instances[dist.rank() :: dist.size()]

    output_path = os.path.join(args.output_dir, f"outputs_{args.task}.jsonl")
    processed_ids, outputs = load_existing_ids(output_path)

    count = len(outputs)
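    # count starts at the number of previously saved records, so the periodic checkpoint
    # in the loop below fires on the combined total rather than only on new predictions.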
    # Run inference
    new_outputs = []
    for instance in tqdm(instances, disable=not dist.is_main()):
        uuid = instance["id"]
        sound_path = instance["sound"]

        if sound_path in processed_ids:
            continue  # Skip if already processed
        sound = llava.Sound(sound_path)
        conversations = instance["conversations"]
        question = conversations[0]["value"]
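        # The first conversation turn is the prompt; the second turn (the reference
        # answer) is saved alongside the prediction below.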

        response = model.generate_content([sound, question], generation_config=generation_config)

        print("response", response)

        output = {"id": sound_path, "question": question, "gt_answer": conversations[1]["value"], "pred": response}
        new_outputs.append(output)
        count += 1
        if count % 20 == 0:
            # Periodically gather and checkpoint outputs so an interrupted run can resume
            if dist.size() > 1:
                gathered = dist.gather(new_outputs, dst=0)
                if dist.is_main():
                    gathered = list(itertools.chain(*gathered))
                    io.save(output_path, outputs + gathered)
            else:
                io.save(output_path, outputs + new_outputs)
    # Final gather and save; on multi-rank runs only the main rank writes the file
    if dist.size() > 1:
        new_outputs = dist.gather(new_outputs, dst=0)
        if not dist.is_main():
            return
        new_outputs = list(itertools.chain(*new_outputs))
    final_outputs = outputs + new_outputs
    io.save(output_path, final_outputs)

if __name__ == "__main__":
    main()