File size: 2,538 Bytes
b14067d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python3
import argparse
import torch
import os
import tqdm
from transformers import T5Tokenizer, T5EncoderModel, AutoTokenizer

def compute_prompt_embeddings(tokenizer, text_encoder, prompts, max_sequence_length=226, device=torch.device("cpu"), dtype=torch.float16):
    """Encode one or more text prompts into embeddings with a T5-style encoder.

    Args:
        tokenizer: HF-style callable returning an object with ``input_ids``.
        text_encoder: model whose first output is the hidden-state tensor.
        prompts: a single prompt string or a list of prompt strings.
        max_sequence_length: fixed pad/truncate length for tokenization.
        device: device the token ids (and result) are placed on.
        dtype: dtype of the returned embedding tensor.

    Returns:
        Tensor of prompt embeddings on ``device`` with dtype ``dtype``.
    """
    # Accept a bare string by promoting it to a batch of one.
    batch = [prompts] if isinstance(prompts, str) else prompts

    tokenized = tokenizer(
        batch,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    input_ids = tokenized.input_ids.to(device)

    # Inference only — no gradients needed for embedding extraction.
    with torch.no_grad():
        hidden_states = text_encoder(input_ids)[0]

    return hidden_states.to(dtype=dtype, device=device)

def main(args):
    """Embed every caption file in ``args.caption_path[start_idx:end_idx]``.

    Loads the tokenizer and T5 text encoder from the pretrained checkpoint,
    reads each caption text file in the assigned chunk, and saves one
    bfloat16 embedding tensor per caption as ``<stem>.pt`` in
    ``args.output_path``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=None
    )
    text_encoder = T5EncoderModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=None
    ).to(device)

    # Deterministic ordering so start_idx/end_idx chunks are stable across workers.
    all_files = sorted(os.listdir(args.caption_path))
    chunk = all_files[args.start_idx: args.end_idx]

    os.makedirs(args.output_path, exist_ok=True)

    for name in tqdm.tqdm(chunk, desc=f"GPU {args.gpu_id}"):
        # Explicit encoding: caption files may contain non-ASCII text and
        # the platform default encoding is not guaranteed to be UTF-8.
        with open(os.path.join(args.caption_path, name), 'r', encoding='utf-8') as f:
            caption = f.read().strip()

        embeddings = compute_prompt_embeddings(
            tokenizer,
            text_encoder,
            caption,
            max_sequence_length=args.max_sequence_length,
            device=device,
            dtype=torch.bfloat16
        ).cpu()
        # BUG FIX: the previous name.replace('.txt', '') removed '.txt'
        # anywhere in the filename (e.g. 'a.txt.b.txt' -> 'a.b'); strip only
        # the final extension instead.
        stem = os.path.splitext(name)[0]
        torch.save(embeddings, os.path.join(args.output_path, stem + '.pt'))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Single-GPU T5 prompt embedding")
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    parser.add_argument("--caption_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--max_sequence_length", type=int, default=226)
    parser.add_argument("--gpu_id", type=int, required=True)
    parser.add_argument("--start_idx", type=int, required=True)
    parser.add_argument("--end_idx", type=int, required=True)
    args = parser.parse_args()
    
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    main(args)