import argparse
import gc
import json
import random

import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from utils import result_writer
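# Fuse structured video captions into natural-language captions with an LLM served by vLLM.
# Rows whose structured caption contains at least one usable subject are rewritten by the
# model; the camera-motion description is appended to every caption before writing the CSV.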
SYSTEM_PROMPT_I2V = """
You are an expert in video captioning. You are given a structured video caption, and you need to compose it into a more natural and fluent English caption.
## Structured Input
{structured_input}
## Notes
1. If a field is empty, just ignore it and do not mention it in the output.
2. Do not make any semantic changes to the original fields. Be sure to follow the original meaning.
3. If the action field is not empty, remove information in it that is unrelated to the temporal action (such as clothing, background, and environment details) so that only the pure action remains.
## Output Principles and Orders
1. First, remove static information from the action field that is unrelated to the temporal action, such as background or environment details.
2. Second, describe each subject with its pure action and expression, if these fields exist.
## Output
Please directly output the final composed caption without any additional information.
"""
SYSTEM_PROMPT_T2V = """
You are an expert in video captioning. You are given a structured video caption, and you need to compose it into a more natural and fluent English caption.
## Structured Input
{structured_input}
## Notes
1. Based on the action field, replace each subject's name field with the subject pronoun used in the action.
2. If a field is empty, just ignore it and do not mention it in the output.
3. Do not make any semantic changes to the original fields. Be sure to follow the original meaning.
## Output Principles and Orders
1. First, declare the shot_type, then declare the shot_angle and the shot_position fields.
2. Second, if the action field is not empty, remove information in it that is unrelated to the temporal action, such as background or environment details.
3. Third, describe each subject with its pure action, appearance, expression, and position, if these fields exist.
4. Finally, declare the environment and lighting if those fields are not empty.
## Output
Please directly output the final composed caption without any additional information.
"""
SHOT_TYPE_LIST = [
'close-up shot',
'extreme close-up shot',
'medium shot',
'long shot',
'full shot',
]
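# Wraps the input CSV: each row's structural_caption JSON is cleaned and rendered into a
# chat prompt for the fusion LLM; rows without usable subjects are passed through as-is.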
class StructuralCaptionDataset(torch.utils.data.Dataset):
    def __init__(self, input_csv, model_path, task):
        self.meta = pd.read_csv(input_csv)
        self.task = task
        self.system_prompt = SYSTEM_PROMPT_T2V if self.task == 't2v' else SYSTEM_PROMPT_I2V
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
def __len__(self):
return len(self.meta)
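    # Build a chat-formatted prompt for one CSV row; rows without usable subjects return
    # a placeholder text and skip LLM fusion.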
def __getitem__(self, index):
row = self.meta.iloc[index]
real_index = self.meta.index[index]
struct_caption = json.loads(row["structural_caption"])
camera_movement = struct_caption.get('camera_motion', '')
if camera_movement != '':
camera_movement += '.'
camera_movement = camera_movement.capitalize()
fusion_by_llm = False
cleaned_struct_caption = self.clean_struct_caption(struct_caption, self.task)
if cleaned_struct_caption.get('num_subjects', 0) > 0:
new_struct_caption = json.dumps(cleaned_struct_caption, indent=4, ensure_ascii=False)
conversation = [
{
"role": "system",
"content": self.system_prompt.format(structured_input=new_struct_caption),
},
]
text = self.tokenizer.apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=True
)
fusion_by_llm = True
else:
text = '-'
return real_index, fusion_by_llm, text, '-', camera_movement
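    # Prune the raw structured caption: strip type metadata, occasionally drop optional
    # fields, and discard subjects that have neither an action nor an expression.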
def clean_struct_caption(self, struct_caption, task):
raw_subjects = struct_caption.get('subjects', [])
subjects = []
for subject in raw_subjects:
subject_type = subject.get("TYPES", {}).get('type', '')
subject_sub_type = subject.get("TYPES", {}).get('sub_type', '')
if subject_type not in ["Human", "Animal"]:
subject['expression'] = ''
if subject_type == 'Human' and subject_sub_type == 'Accessory':
subject['expression'] = ''
if subject_sub_type != '':
subject['name'] = subject_sub_type
if 'TYPES' in subject:
del subject['TYPES']
if 'is_main_subject' in subject:
del subject['is_main_subject']
subjects.append(subject)
to_del_subject_ids = []
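        # Randomly drop optional fields (roughly 10% of the time) to diversify the prompts,
        # then flag subjects that end up with neither an action nor an expression.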
for idx, subject in enumerate(subjects):
action = subject.get('action', '').strip()
subject['action'] = action
if random.random() > 0.9 and 'appearance' in subject:
del subject['appearance']
if random.random() > 0.9 and 'position' in subject:
del subject['position']
if task == 'i2v':
                # for i2v, keep only the name, action, and expression fields of each subject
dropped_keys = ['appearance', 'position']
for key in dropped_keys:
if key in subject:
del subject[key]
if subject['action'] == '' and ('expression' not in subject or subject['expression'] == ''):
to_del_subject_ids.append(idx)
        # remove the flagged subjects, iterating in reverse so earlier indices stay valid
for idx in sorted(to_del_subject_ids, reverse=True):
del subjects[idx]
        # normalize shot_type (e.g. 'medium_shot' -> 'medium shot') and drop values outside the supported list
        shot_type = struct_caption.get('shot_type', '').replace('_', ' ')
        struct_caption['shot_type'] = shot_type if shot_type in SHOT_TYPE_LIST else ''
new_struct_caption = {
'num_subjects': len(subjects),
'subjects': subjects,
'shot_type': struct_caption.get('shot_type', ''),
'shot_angle': struct_caption.get('shot_angle', ''),
'shot_position': struct_caption.get('shot_position', ''),
'environment': struct_caption.get('environment', ''),
'lighting': struct_caption.get('lighting', ''),
}
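        # t2v occasionally drops lighting so captions do not always mention it; i2v keeps only subject-level fields.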
if task == 't2v' and random.random() > 0.9:
del new_struct_caption['lighting']
if task == 'i2v':
drop_keys = ['environment', 'lighting', 'shot_type', 'shot_angle', 'shot_position']
for drop_key in drop_keys:
del new_struct_caption[drop_key]
return new_struct_caption
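# Keep every batch field as a plain Python list; the prompts are strings, so no tensor collation is needed.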
def custom_collate_fn(batch):
real_indices, fusion_by_llm, texts, original_texts, camera_movements = zip(*batch)
return list(real_indices), list(fusion_by_llm), list(texts), list(original_texts), list(camera_movements)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Caption Fusion by LLM")
parser.add_argument("--input_csv", default="./examples/test_result.csv")
parser.add_argument("--out_csv", default="./examples/test_result_caption.csv")
parser.add_argument("--bs", type=int, default=4)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--model_path", required=True, type=str, help="LLM model path")
parser.add_argument("--task", default='t2v', help="t2v or i2v")
args = parser.parse_args()
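    # Low temperature keeps the rewrite close to the structured input; generation
    # stops at the first blank line so the model cannot append extra commentary.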
sampling_params = SamplingParams(
temperature=0.1,
max_tokens=512,
stop=['\n\n']
)
# model_path = "/maindata/data/shared/public/Common-Models/Qwen2.5-32B-Instruct/"
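    # vLLM engine; a 4096-token context comfortably fits the structured prompt plus the 512-token caption.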
llm = LLM(
model=args.model_path,
gpu_memory_utilization=0.9,
max_model_len=4096,
        tensor_parallel_size=args.tp,
)
    dataset = StructuralCaptionDataset(input_csv=args.input_csv, model_path=args.model_path, task=args.task)
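    # shuffle=False and drop_last=False: every row is processed exactly once, in CSV order.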
dataloader = DataLoader(
dataset,
batch_size=args.bs,
num_workers=8,
collate_fn=custom_collate_fn,
shuffle=False,
drop_last=False,
)
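    # Rows flagged for LLM fusion are batched through vLLM; all other rows fall back to
    # their placeholder text plus the camera-movement description.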
indices_list = []
result_list = []
for indices, fusion_by_llms, texts, original_texts, camera_movements in tqdm(dataloader):
llm_indices, llm_texts, llm_original_texts, llm_camera_movements = [], [], [], []
for idx, fusion_by_llm, text, original_text, camera_movement in zip(indices, fusion_by_llms, texts, original_texts, camera_movements):
if fusion_by_llm:
llm_indices.append(idx)
llm_texts.append(text)
llm_original_texts.append(original_text)
llm_camera_movements.append(camera_movement)
else:
indices_list.append(idx)
caption = original_text + " " + camera_movement
result_list.append(caption)
if len(llm_texts) > 0:
try:
outputs = llm.generate(llm_texts, sampling_params, use_tqdm=False)
results = []
for output in outputs:
result = output.outputs[0].text.strip()
results.append(result)
indices_list.extend(llm_indices)
except Exception as e:
print(f"Error at {llm_indices}: {str(e)}")
indices_list.extend(llm_indices)
results = llm_original_texts
for result, camera_movement in zip(results, llm_camera_movements):
# concat camera movement to fusion_caption
llm_caption = result + " " + camera_movement
result_list.append(llm_caption)
torch.cuda.empty_cache()
gc.collect()
meta_new = result_writer(indices_list, result_list, dataset.meta, column=[f"{args.task}_fusion_caption"])
meta_new.to_csv(args.out_csv, index=False)
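# Example invocation (script name and model id are illustrative; the paths match the argparse defaults above):
#   python caption_fusion_by_llm.py \
#       --input_csv ./examples/test_result.csv \
#       --out_csv ./examples/test_result_caption.csv \
#       --model_path Qwen/Qwen2.5-32B-Instruct \
#       --task t2v --bs 4 --tp 1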