from copy import deepcopy

import gradio as gr
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import BitsAndBytesConfig

from sentence_transformers import SentenceTransformer

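# Quantize the LLaVA weights to 4-bit so the 7B model fits on a small GPU; matmuls are computed in float16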
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

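# Sentence embedder for the generated descriptions, plus the quantized LLaVA-1.5-7B model and its processor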
embedder = SentenceTransformer('all-mpnet-base-v2')
model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    # use_flash_attention_2=True,
    low_cpu_mem_usage=True
)

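# Empirical ceiling on the total number of pixel values that fit in a single generate() call on a 16 GB T4;
# requests above this are split into sub-batches below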
MAXIMUM_PIXEL_VALUES = 3725568

def text_to_image(image, prompt, duplications: float):
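    """Describe `image` with LLaVA, duplicating the image/prompt pair `duplications` times to stress memory.

    Returns the joined descriptions and their sentence embeddings.
    """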
    prompt = f'USER: <image>\n{prompt}\nASSISTANT:'

    image_batch = [image]
    prompt_batch = [prompt]
    for _ in range(int(duplications)):
        image_batch.append(deepcopy(image))
        prompt_batch.append(prompt)

    inputs = processor(text=prompt_batch, images=image_batch, padding=True, return_tensors="pt")

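    # If the request would not fit in one forward pass, split it into sub-batches under the pixel budget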
    batched_inputs: list[dict[str, torch.Tensor]] = list()
    if inputs['pixel_values'].flatten().shape[0] > MAXIMUM_PIXEL_VALUES:
        batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
        i = 0
        while i < len(inputs['pixel_values']):
            batch['input_ids'].append(inputs['input_ids'][i])
            batch['attention_mask'].append(inputs['attention_mask'][i])
            batch['pixel_values'].append(inputs['pixel_values'][i])

            if torch.cat(batch['pixel_values'], dim=0).flatten().shape[0] > MAXIMUM_PIXEL_VALUES:
                print(f'[{i}/{len(inputs["pixel_values"])}] - Reached max pixel values for batch prediction on T4 '
                      f'16GB GPU. Will split in more batches')
                # Remove the last added image because it's too big to process
                batch['input_ids'].pop()
                batch['attention_mask'].pop()
                batch['pixel_values'].pop()

                # transform lists to tensors
                batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
                batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
                batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)

                # Add to the batched_inputs
                batched_inputs.append(batch)
                batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
            else:
                i += 1

        # Flush the final partial batch that never hit the pixel ceiling
        if batch['input_ids']:
            batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
            batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
            batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
            batched_inputs.append(batch)
    else:
        batched_inputs.append(inputs)

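    # Generate answers batch by batch, decode them, and embed each answer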
    maurice_description = list()
    maurice_embeddings = list()
    for batch in batched_inputs:
        # Move the sub-batch onto the GPU (Tensor.to is not in-place, so reassign)
        batch['input_ids'] = batch['input_ids'].to(model.device)
        batch['attention_mask'] = batch['attention_mask'].to(model.device)
        batch['pixel_values'] = batch['pixel_values'].to(model.device)
        output = model.generate(**batch, max_new_tokens=500, do_sample=True, temperature=0.3)

        # Move the sub-batch back to the CPU to free GPU memory
        batch['input_ids'] = batch['input_ids'].to('cpu')
        batch['attention_mask'] = batch['attention_mask'].to('cpu')
        batch['pixel_values'] = batch['pixel_values'].to('cpu')

        generated_text = processor.batch_decode(output, skip_special_tokens=True)
        output = output.to('cpu')

        for text in generated_text:
            text_output = text.split("ASSISTANT:")[-1]
            # encode() returns a numpy array; convert it to a plain list so gr.JSON can serialize it
            text_embeddings = embedder.encode(text_output).tolist()
            maurice_description.append(text_output)
            maurice_embeddings.append(text_embeddings)

    return '\n---\n'.join(maurice_description), dict(text_embeddings=maurice_embeddings)
    # inputs = inputs.to(model.device)
    # print()
    # output = model.generate(**inputs, max_new_tokens=500, temperature=0.3)
    # generated_text = processor.batch_decode(output, skip_special_tokens=True)
    # text = generated_text.pop()
    # text_output = text.split("ASSISTANT:")[-1]
    # text_embeddings = embedder.encode(text_output)
    #
    # return text_output, dict(text_embeddings=text_embeddings)


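# Gradio UI: an image, a free-text prompt, and a duplication count used to stress-test GPU memory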
demo = gr.Interface(
    fn=text_to_image,
    inputs=[
        gr.Image(label='Select an image to analyze', type='pil'),
        gr.Textbox(label='Enter Prompt'),
        gr.Number(label='How many duplications of the image (to test memory load)', value=0)
    ],
    outputs=[gr.Textbox(label='Maurice says:'), gr.JSON(label='Embedded text')]
)

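# Launch the app; show_api=False hides the auto-generated API docs page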
if __name__ == "__main__":
    demo.launch(show_api=False)