File size: 4,243 Bytes
735672d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import argparse
import PIL
from PIL import Image
import os
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from conversation import conv_templates, SeparatorStyle
from torchvision import transforms

from constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN

from threading import Thread
from unitok.config import Args
from unitok.model import UniTok

from model.builder import load_pretrained_model
from mm_utils import tokenizer_image_token, get_model_name_from_path


# Sentinel id handed to ``tokenizer_image_token`` to mark where the image goes in
# ``input_ids``. NOTE(review): this re-binds the name imported from ``constants``
# above, overriding it; keep the two values in sync (or drop one of them).
IMAGE_TOKEN_INDEX = -200

def expand2square(pil_img, background_color):
    """Pad ``pil_img`` onto a square canvas, centring it along its shorter axis.

    Args:
        pil_img: source PIL image.
        background_color: fill colour for the padding, passed to ``Image.new``.

    Returns:
        ``pil_img`` itself when it is already square. Otherwise, a new image of
        side ``max(width, height)`` in the same mode, with ``pil_img`` pasted
        flush on the longer axis and centred on the shorter one.
    """
    w, h = pil_img.size
    if w == h:
        return pil_img

    side = max(w, h)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Only the shorter axis gets an offset; the longer axis starts at 0.
    offset = (0, (w - h) // 2) if w > h else ((h - w) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas
    

def _load_unitok(unitok_path):
    """Build the UniTok image tokenizer from a checkpoint, on CUDA, in eval mode.

    The checkpoint must hold the tokenizer config under ``'args'`` and the
    weights under ``['trainer']['unitok']``.
    """
    ckpt = torch.load(unitok_path, map_location='cpu')
    vae_cfg = Args()
    vae_cfg.load_state_dict(ckpt['args'])
    vq_model = UniTok(vae_cfg)
    vq_model.load_state_dict(ckpt['trainer']['unitok'])
    vq_model.to('cuda')
    vq_model.eval()
    return vq_model


def _build_prompt(question):
    """Render ``question`` with a leading image placeholder via the ``llava_v1`` template.

    The placeholder ``<boi><image><eoi>`` goes on its own line before the
    question. That text is added as the first role's message, and an empty
    (``None``) message is added for the second role before rendering.
    """
    qs = '<boi><image><eoi>' + '\n' + question
    conv = conv_templates['llava_v1'].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


def _preprocess_image(image_path, crop_size=256):
    """Load an image as RGB, pad it to a square, and return a normalized CUDA batch.

    Returns:
        Tensor of shape ``(1, 3, crop_size, crop_size)`` on ``'cuda'``. Values are
        normalized with mean/std 0.5, which maps ToTensor's [0, 1] to [-1, 1].
    """
    transform = transforms.Compose([
        transforms.Resize((crop_size, crop_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it once convert() has produced an independent in-memory copy.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    # NOTE(review): the pad colour looks like the CLIP image mean scaled to
    # 0-255 (LLaVA's convention) while normalization uses 0.5 -- confirm this
    # matches how the model was trained.
    pad_image = expand2square(image, (122, 116, 104))
    return transform(pad_image).unsqueeze(0).to('cuda')


def _stream_generate(vqllm, streamer, generation_kwargs):
    """Run ``vqllm.generate_mllm`` on a worker thread and echo text as it streams.

    Returns:
        The full generated text.

    Raises:
        Any exception raised by ``generate_mllm``. It is re-raised here after
        the stream has been closed and the worker joined.
    """
    errors = []

    def _worker():
        try:
            # Grad mode is thread-local in PyTorch: a no_grad() entered by the
            # caller does not apply on this thread, so enter it here.
            with torch.no_grad():
                vqllm.generate_mllm(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            # Push the stop signal so the consumer loop below does not block
            # forever waiting on a stream that will never finish.
            streamer.end()

    thread = Thread(target=_worker, daemon=True)
    thread.start()
    chunks = []
    for new_text in streamer:
        print(new_text, end='', flush=True)
        chunks.append(new_text)
    print()
    thread.join()
    if errors:
        raise errors[0]
    return ''.join(chunks)


def main(args):
    """Run the UniTok MLLM on one image plus a text prompt and stream the reply to stdout.

    Args:
        args: parsed CLI namespace. Reads ``unitok_path``, ``mllm_path``,
            ``prompt``, ``image_path`` and ``load_8bit``.

    Returns:
        The generated text (it is also printed as it arrives).
    """
    vq_model = _load_unitok(args.unitok_path)

    model_path = os.path.expanduser(args.mllm_path)
    model_name = get_model_name_from_path(model_path)
    # The image processor and context length returned here are not needed.
    tokenizer, vqllm, _, _ = load_pretrained_model(model_path, model_name, load_8bit=args.load_8bit)

    prompt = _build_prompt(args.prompt)
    print(prompt)

    img = _preprocess_image(args.image_path)
    with torch.no_grad():
        image_codes = vq_model.img_to_idx(img).unsqueeze(0)

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        'inputs': input_ids.unsqueeze(0).to('cuda:0'),
        'images': image_codes.to('cuda:0'),
        'max_new_tokens': 1024,
        'bos_token_id': tokenizer.bos_token_id,
        'eos_token_id': tokenizer.eos_token_id,
        'pad_token_id': tokenizer.pad_token_id,
        'streamer': streamer,
    }
    return _stream_generate(vqllm, streamer, generation_kwargs)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Run the UniTok MLLM on an image and a text prompt, streaming the reply.')
    # NOTE(review): the two path defaults below are developer-local Windows
    # paths; pass --unitok_path / --mllm_path explicitly on other machines.
    parser.add_argument(
        '--unitok_path', type=str,
        default=r'D:\projects\liquid_app\UniTok\UniTok_weights\unitok_tokenizer\unitok_tokenizer.pth',
        help='path to the UniTok tokenizer checkpoint (.pth)')
    parser.add_argument(
        '--mllm_path', type=str,
        default=r'C:\debug_ckpts\unitok_mllm',
        help='path to the UniTok MLLM checkpoint')
    parser.add_argument('--prompt', type=str, required=True, help='input text prompt')
    parser.add_argument('--image_path', type=str, required=True, help='input image path')
    parser.add_argument('--load_8bit', action='store_true', help='use 8bit to save memory')

    args = parser.parse_args()
    main(args)