from copy import deepcopy

import gradio as gr
import torch
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    LlavaForConditionalGeneration,
    PretrainedConfig,
)

# 4-bit quantization so the 7B model fits on a single 16 GB GPU
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
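
# Sentence embedder used to turn each generated description into a vector for the gr.JSON output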
embedder = SentenceTransformer('all-mpnet-base-v2')
model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
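
# LLaVA 1.5 7B loaded in 4-bit; device_map="auto" places the quantized weights on the available GPU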
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    # use_flash_attention_2=True,
    low_cpu_mem_usage=True,
    # config=PretrainedConfig(do_sample=True)
)
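
# Cap on the flattened pixel_values tensor per generate() call: 3,725,568 values
# = 11 images at 3x336x336 (the LLaVA-1.5 input resolution), sized for a 16 GB T4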
MAXIMUM_PIXEL_VALUES = 3725568


def text_to_image(image, prompt, duplications: float):
    """Describe `image` with LLaVA and embed the description.

    The image/prompt pair is duplicated `duplications` times so the
    batch-splitting logic below can be exercised under memory load.
    """
    prompt = f'USER: <image>\n{prompt}\nASSISTANT:'
    image_batch = [image]
    prompt_batch = [prompt]
    for _ in range(int(duplications)):
        image_batch.append(deepcopy(image))
        prompt_batch.append(prompt)

    inputs = processor(prompt_batch, images=image_batch, padding=True, return_tensors="pt")
    batched_inputs: list[dict[str, torch.Tensor]] = list()
    # If the whole batch is too large for a single forward pass, split it into
    # sub-batches that each stay under MAXIMUM_PIXEL_VALUES.
    if inputs['pixel_values'].flatten().shape[0] > MAXIMUM_PIXEL_VALUES:
        batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
        i = 0
        while i < len(inputs['pixel_values']):
            batch['input_ids'].append(inputs['input_ids'][i])
            batch['attention_mask'].append(inputs['attention_mask'][i])
            batch['pixel_values'].append(inputs['pixel_values'][i])
            if torch.cat(batch['pixel_values'], dim=0).flatten().shape[0] > MAXIMUM_PIXEL_VALUES:
                print(f'[{i}/{len(inputs["pixel_values"])}] - Reached max pixel values for batch prediction on a T4 '
                      f'16 GB GPU. Will split into more batches.')
                # Remove the last added image because it makes the batch too big to process
                batch['input_ids'].pop()
                batch['attention_mask'].pop()
                batch['pixel_values'].pop()
                # Transform lists to tensors
                batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
                batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
                batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
                # Add to batched_inputs; the popped image is retried in the next batch (i is not incremented)
                batched_inputs.append(batch)
                batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
            else:
                i += 1
                if i >= len(inputs['pixel_values']) and len(batch['input_ids']) > 0:
                    # Flush the last, partially filled batch
                    batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
                    batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
                    batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
                    batched_inputs.append(batch)
                    batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
    else:
        batched_inputs.append(inputs)

    maurice_description = list()
    maurice_embeddings = list()
    for batch in batched_inputs:
        # Load on device
        batch['input_ids'] = batch['input_ids'].to(model.device)
        batch['attention_mask'] = batch['attention_mask'].to(model.device)
        batch['pixel_values'] = batch['pixel_values'].to(model.device)
        # output = model.generate(**batch, max_new_tokens=500, temperature=0.3)
        output = model.generate(**batch, max_new_tokens=500)
        # Unload from GPU (`.to()` is not in-place, so reassign to actually drop the GPU copies)
        batch['input_ids'] = batch['input_ids'].to('cpu')
        batch['attention_mask'] = batch['attention_mask'].to('cpu')
        batch['pixel_values'] = batch['pixel_values'].to('cpu')
        generated_text = processor.batch_decode(output, skip_special_tokens=True)
        output = output.to('cpu')
        for text in generated_text:
            text_output = text.split("ASSISTANT:")[-1]
            # .tolist() keeps the embedding JSON-serializable for the gr.JSON output
            text_embeddings = embedder.encode(text_output).tolist()
            maurice_description.append(text_output)
            maurice_embeddings.append(text_embeddings)

    return '\n---\n'.join(maurice_description), dict(text_embeddings=maurice_embeddings)

demo = gr.Interface(
    fn=text_to_image,
    inputs=[
        gr.Image(label='Select an image to analyze', type='pil'),
        gr.Textbox(label='Enter Prompt'),
        gr.Number(label='How many duplications of the image (to test memory load)', value=0),
    ],
    outputs=[gr.Textbox(label='Maurice says:'), gr.JSON(label='Embedded text')],
)

if __name__ == "__main__":
    demo.launch(show_api=False)