"""
This file is used to extract features of the empty prompt.
"""
| import os | |
| import sys | |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| import torch | |
| import os | |
| import numpy as np | |
| from libs.clip import FrozenCLIPEmbedder | |
| from libs.t5 import T5Embedder | |
def main():
    """Extract and save text-encoder features for the empty prompt.

    Encodes each prompt in ``prompts`` (by default only the empty string)
    with the selected language model (``'clip'`` or ``'t5'``) and writes the
    per-prompt token embeddings, token mask, token ids, and caption as a
    dict to ``empty_context.npy`` in the current directory.

    Raises:
        NotImplementedError: if ``llm`` is neither ``'clip'`` nor ``'t5'``.
    """
    prompts = [
        '',
    ]
    # Fall back to CPU when no GPU is present so the script still runs.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    llm = 'clip'
    save_dir = './'

    # Single dispatch on the encoder choice: build the model and extract
    # features in one branch (the original duplicated this if/elif chain).
    # no_grad avoids building an autograd graph during pure inference.
    with torch.no_grad():
        if llm == 'clip':
            clip = FrozenCLIPEmbedder()
            clip.eval()
            clip.to(device)
            _, latent_and_others = clip.encode(prompts)
            token_embedding = latent_and_others['token_embedding']
        elif llm == 't5':
            t5 = T5Embedder(device=device)
            _, latent_and_others = t5.get_text_embeddings(prompts)
            # T5 features are cast to fp32 and rescaled by 10x — presumably to
            # roughly match CLIP's feature magnitude; TODO confirm downstream.
            token_embedding = latent_and_others['token_embedding'].to(torch.float32) * 10.0
        else:
            raise NotImplementedError(f'unsupported llm: {llm}')
        token_mask = latent_and_others['token_mask']
        token = latent_and_others['tokens']

    for i, caption in enumerate(prompts):
        data = {'token_embedding': token_embedding[i].detach().cpu().numpy(),
                'token_mask': token_mask[i].detach().cpu().numpy(),
                'token': token[i].detach().cpu().numpy(),
                'batch_caption': caption}
        # NOTE(review): every prompt is written to the SAME file, so only the
        # last prompt's features survive. Fine for the single empty prompt this
        # script targets, but the filename would need the index for more prompts.
        np.save(os.path.join(save_dir, 'empty_context.npy'), data)
# Run the extraction only when executed as a script, not when imported.
if __name__ == '__main__':
    main()