|
|
|
import argparse |
|
import torch |
|
import os |
|
import tqdm |
|
from transformers import T5Tokenizer, T5EncoderModel, AutoTokenizer |
|
|
|
def compute_prompt_embeddings(tokenizer, text_encoder, prompts, max_sequence_length=226, device=torch.device("cpu"), dtype=torch.float16):
    """Tokenize prompts and return their encoder embeddings.

    Args:
        tokenizer: callable returning an object with an ``input_ids`` tensor.
        text_encoder: model whose first output is the hidden-state tensor.
        prompts: a single prompt string or a list of prompt strings.
        max_sequence_length: pad/truncate target length for tokenization.
        device: device the token ids and result are placed on.
        dtype: dtype the returned embeddings are cast to.

    Returns:
        Embedding tensor of shape (batch, seq, hidden) on *device* in *dtype*.
    """
    # Promote a lone string to a one-element batch.
    batch = [prompts] if isinstance(prompts, str) else prompts

    tokenized = tokenizer(
        batch,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt"
    )
    input_ids = tokenized.input_ids.to(device)

    # Inference only — gradients are never needed when precomputing embeddings.
    with torch.no_grad():
        hidden = text_encoder(input_ids)[0]

    return hidden.to(dtype=dtype, device=device)
|
|
|
def main(args):
    """Embed one shard of caption files with a T5 encoder and save tensors.

    Reads every file in ``args.caption_path`` (sorted, sliced to
    ``[start_idx, end_idx)``), encodes each caption in bfloat16, and writes
    the CPU tensor to ``args.output_path`` with the ``.txt`` suffix swapped
    for ``.pt``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=None
    )
    text_encoder = T5EncoderModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=None
    ).to(device)

    # Sorted listing makes the (start_idx, end_idx) shard stable across
    # concurrent worker processes.
    all_files = sorted(os.listdir(args.caption_path))
    chunk = all_files[args.start_idx: args.end_idx]

    os.makedirs(args.output_path, exist_ok=True)

    for name in tqdm.tqdm(chunk, desc=f"GPU {args.gpu_id}"):
        # Explicit encoding: platform-default could mangle non-ASCII captions.
        with open(os.path.join(args.caption_path, name), 'r', encoding='utf-8') as f:
            caption = f.read().strip()

        embeddings = compute_prompt_embeddings(
            tokenizer,
            text_encoder,
            caption,
            max_sequence_length=args.max_sequence_length,
            device=device,
            dtype=torch.bfloat16
        ).cpu()
        # BUG FIX: the old name.replace('.txt', '') stripped '.txt' anywhere
        # in the filename (e.g. 'a.txt.bak.txt' -> 'a.bak'); strip only a
        # trailing '.txt'.
        stem = name[:-4] if name.endswith('.txt') else name
        torch.save(embeddings, os.path.join(args.output_path, stem + '.pt'))
|
|
|
if __name__ == "__main__":
    # CLI entry point: each invocation embeds one shard [start_idx, end_idx)
    # of the caption directory on a single GPU.
    parser = argparse.ArgumentParser(description="Single-GPU T5 prompt embedding")
    for flag in ("--pretrained_model_name_or_path", "--caption_path", "--output_path"):
        parser.add_argument(flag, type=str, required=True)
    parser.add_argument("--max_sequence_length", type=int, default=226)
    for flag in ("--gpu_id", "--start_idx", "--end_idx"):
        parser.add_argument(flag, type=int, required=True)
    args = parser.parse_args()

    # Pin this process to a single device before any CUDA context is created,
    # so torch only ever sees the requested GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    main(args)
|
|