from copy import deepcopy
import gradio as gr
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration
from sentence_transformers import SentenceTransformer, util
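
# Load the LLaVA weights in 4-bit so the 7B model fits on a single 16 GB GPU.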
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)
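
# Sentence embedder used to vectorize the generated descriptions for the JSON output.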
embedder = SentenceTransformer('all-mpnet-base-v2')
model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    # use_flash_attention_2=True,
    low_cpu_mem_usage=True
)
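
# Empirical cap for batch prediction on a T4 16GB GPU: 3 * 336 * 336 * 11 tensor
# elements, i.e. eleven 336x336 RGB images (LLaVA-1.5's input resolution) per forward pass.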
MAXIMUM_PIXEL_VALUES = 3725568
def image_to_text(image, prompt, duplications: float):
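    """Describe `image` with LLaVA and embed the resulting text.

    The image/prompt pair is duplicated `duplications` extra times to stress-test
    GPU memory; inputs are split into sub-batches that stay under
    MAXIMUM_PIXEL_VALUES so each generate() call fits on the GPU.
    """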
    prompt = f'USER: <image>\n{prompt}\nASSISTANT:'
    image_batch = [image]
    prompt_batch = [prompt]
    for _ in range(int(duplications)):
        image_batch.append(deepcopy(image))
        prompt_batch.append(prompt)
    inputs = processor(text=prompt_batch, images=image_batch, padding=True, return_tensors="pt")
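
    # Split the encoded inputs into GPU-sized chunks whenever the combined
    # pixel_values tensor is too large for a single forward pass.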
    batched_inputs: list[dict[str, torch.Tensor]] = list()
    if inputs['pixel_values'].flatten().shape[0] > MAXIMUM_PIXEL_VALUES:
        batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
        i = 0
        while i < len(inputs['pixel_values']):
            batch['input_ids'].append(inputs['input_ids'][i])
            batch['attention_mask'].append(inputs['attention_mask'][i])
            batch['pixel_values'].append(inputs['pixel_values'][i])
            if torch.cat(batch['pixel_values'], dim=0).flatten().shape[0] > MAXIMUM_PIXEL_VALUES:
                print(f'[{i}/{len(inputs["pixel_values"])}] - Reached max pixel values for batch prediction on a '
                      f'T4 16GB GPU. Splitting into more batches.')
                # Remove the last added image: it no longer fits in the current batch
                batch['input_ids'].pop()
                batch['attention_mask'].pop()
                batch['pixel_values'].pop()
                # Transform lists to tensors
                batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
                batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
                batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
                # Add to the batched_inputs and start a new batch; the image that
                # overflowed is re-added on the next iteration (i is not incremented)
                batched_inputs.append(batch)
                batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
            else:
                i += 1
        # Flush the final partial batch, which the loop above never appends
        if batch['input_ids']:
            batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
            batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
            batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
            batched_inputs.append(batch)
    else:
        batched_inputs.append(inputs)
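
    # Generate batch by batch, keeping tensors on the GPU only while needed.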
    maurice_description = list()
    maurice_embeddings = list()
    for batch in batched_inputs:
        # Load on device (Tensor.to is not in-place, so reassign)
        batch['input_ids'] = batch['input_ids'].to(model.device)
        batch['attention_mask'] = batch['attention_mask'].to(model.device)
        batch['pixel_values'] = batch['pixel_values'].to(model.device)
        output = model.generate(**batch, max_new_tokens=500, do_sample=True, temperature=0.3)
        # Unload GPU
        batch['input_ids'] = batch['input_ids'].to('cpu')
        batch['attention_mask'] = batch['attention_mask'].to('cpu')
        batch['pixel_values'] = batch['pixel_values'].to('cpu')
        generated_text = processor.batch_decode(output, skip_special_tokens=True)
        output = output.to('cpu')
        for text in generated_text:
            text_output = text.split("ASSISTANT:")[-1].strip()
            # encode() returns a numpy array; convert to a plain list so gr.JSON can serialize it
            text_embeddings = embedder.encode(text_output)
            maurice_description.append(text_output)
            maurice_embeddings.append(text_embeddings.tolist())
    return '\n---\n'.join(maurice_description), dict(text_embeddings=maurice_embeddings)
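
# Gradio UI: an image and a prompt in, the generated description and its embedding out.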
demo = gr.Interface(
    fn=image_to_text,
    inputs=[
        gr.Image(label='Select an image to analyze', type='pil'),
        gr.Textbox(label='Enter Prompt'),
        gr.Number(label='How many duplications of the image (to test memory load)', value=0)
    ],
    outputs=[gr.Textbox(label='Maurice says:'), gr.JSON(label='Embedded text')]
)
if __name__ == "__main__":
    demo.launch(show_api=False)