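"""Gradio demo: describe an uploaded image with a 4-bit quantized LLaVA-1.5-7B model.

The image/prompt pair can be duplicated to stress-test GPU memory; oversized requests
are split into batches that stay under the pixel budget of a 16GB GPU. Each generated
description is also embedded with all-mpnet-base-v2 and returned alongside the text.
"""
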
from copy import deepcopy

import gradio as gr
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration

# Quantize the weights to 4-bit (fp16 compute) so the 7B model fits on a 16GB GPU
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

# Sentence embedder used to turn each generated description into a vector
embedder = SentenceTransformer('all-mpnet-base-v2')

# LLaVA 1.5 (7B) and its processor (tokenizer + image preprocessor)
model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
# Load LLaVA with the 4-bit config; device_map="auto" places the weights on the available GPU(s)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    # use_flash_attention_2=True,
    low_cpu_mem_usage=True,
)

# Empirical cap on the total number of pixel values (batch x channels x height x width)
# that fits in 16GB of GPU memory during generation; larger requests are split below.
MAXIMUM_PIXEL_VALUES = 3725568

def text_to_image(image, prompt, duplications: float):
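    """Generate a text description of `image` with LLaVA and embed it.

    The image/prompt pair is duplicated `duplications` times to simulate a larger
    batch; requests whose pixel tensors exceed MAXIMUM_PIXEL_VALUES are split into
    smaller batches before generation. Returns the generated description(s) joined
    by '---' and a dict with their sentence embeddings.
    """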
    prompt = f'USER: <image>\n{prompt}\nASSISTANT:'

    image_batch = [image]
    prompt_batch = [prompt]
    for _ in range(int(duplications)):
        image_batch.append(deepcopy(image))
        prompt_batch.append(prompt)

    # Tokenize the prompts and preprocess the images into input_ids, attention_mask and pixel_values
    inputs = processor(prompt_batch, images=image_batch, padding=True, return_tensors="pt")

    # If the stacked pixel tensor is too large for a single forward pass, split the request into smaller batches
    batched_inputs: list[dict[str, torch.Tensor]] = list()
    if inputs['pixel_values'].flatten().shape[0] > MAXIMUM_PIXEL_VALUES:
        batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
        i = 0
        while i < len(inputs['pixel_values']):
            batch['input_ids'].append(inputs['input_ids'][i])
            batch['attention_mask'].append(inputs['attention_mask'][i])
            batch['pixel_values'].append(inputs['pixel_values'][i])

            if torch.cat(batch['pixel_values'], dim=0).flatten().shape[0] > MAXIMUM_PIXEL_VALUES:
                print(f'[{i}/{len(inputs["pixel_values"])}] - Reached max pixel values for batch prediction on a T4 '
                      f'16GB GPU. Splitting into smaller batches')
                # Remove the last added image because it's too big to process
                batch['input_ids'].pop()
                batch['attention_mask'].pop()
                batch['pixel_values'].pop()

                # transform lists to tensors
                batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
                batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
                batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)

                # Add to the batched_inputs
                batched_inputs.append(batch)
                batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
            else:
                i += 1
        # Flush whatever is left over as the final, partial batch
        if len(batch['input_ids']) > 0:
            batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
            batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
            batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)

            # Add to the batched_inputs
            batched_inputs.append(batch)
    else:
        batched_inputs.append(inputs)

    # Generate batch by batch, then decode and embed every answer
    maurice_description = list()
    maurice_embeddings = list()
    for batch in batched_inputs:
        # Load on device
        batch['input_ids'] = batch['input_ids'].to(model.device)
        batch['attention_mask'] = batch['attention_mask'].to(model.device)
        batch['pixel_values'] = batch['pixel_values'].to(model.device)
        # output = model.generate(**batch, max_new_tokens=500, temperature=0.3)
        output = model.generate(**batch, max_new_tokens=500)
        # Move the inputs back to CPU so the GPU copies can be freed
        batch['input_ids'] = batch['input_ids'].to('cpu')
        batch['attention_mask'] = batch['attention_mask'].to('cpu')
        batch['pixel_values'] = batch['pixel_values'].to('cpu')

        generated_text = processor.batch_decode(output, skip_special_tokens=True)
        output = output.to('cpu')

        for text in generated_text:
            text_output = text.split("ASSISTANT:")[-1]
            text_embeddings = embedder.encode(text_output)
            maurice_description.append(text_output)
            maurice_embeddings.append(text_embeddings)

    return '\n---\n'.join(maurice_description), dict(text_embeddings=maurice_embeddings)


demo = gr.Interface(
    fn=text_to_image,
    inputs=[
        gr.Image(label='Select an image to analyze', type='pil'),
        gr.Textbox(label='Enter Prompt'),
        gr.Number(label='How many duplications of the image (to test memory load)', value=0)
    ],
    outputs=[gr.Textbox(label='Maurice says:'), gr.JSON(label='Embedded text')]
)

if __name__ == "__main__":
    demo.launch(show_api=False)