# video_ad_classifier/image_caption.py
import argparse
import io

import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer


class Caption:
    """Generates image captions with the nlpconnect/vit-gpt2-image-captioning model."""

    def __init__(self):
        self.model = VisionEncoderDecoderModel.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )
        self.feature_extractor = ViTImageProcessor.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )
        # Run on GPU when available, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.max_length = 16
        self.num_beams = 4
        self.gen_kwargs = {"max_length": self.max_length, "num_beams": self.num_beams}

    def predict_step(self, image_paths):
        """Caption a batch of images given their file paths."""
        images = []
        for image_path in image_paths:
            i_image = Image.open(image_path)
            if i_image.mode != "RGB":
                i_image = i_image.convert(mode="RGB")
            images.append(i_image)
        return self.process_images(images)

    def predict_from_memory(self, image_buffers):
        """Caption a batch of images held in memory (e.g. io.BytesIO buffers)."""
        images = []
        for image_buffer in image_buffers:
            # Ensure the buffer is positioned at the start before decoding.
            if isinstance(image_buffer, io.BytesIO):
                image_buffer.seek(0)
            try:
                i_image = Image.open(image_buffer)
                if i_image.mode != "RGB":
                    i_image = i_image.convert("RGB")
                images.append(i_image)
            except Exception as e:
                print(f"Failed to process image buffer: {e}")
                continue
        return self.process_images(images)

    def process_images(self, images):
        """Run the model over a batch of PIL images and decode the captions."""
        if not images:
            # Nothing decoded successfully; avoid passing an empty batch to the model.
            return []
        pixel_values = self.feature_extractor(images=images, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(self.device)
        # Inference only: disable gradient tracking to save memory.
        with torch.no_grad():
            output_ids = self.model.generate(pixel_values, **self.gen_kwargs)
        preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return [pred.strip() for pred in preds]

    def get_args(self):
        """Parse command-line arguments for the standalone demo."""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "-i",
            "--input_img_paths",
            type=str,
            default="farmer.jpg",
            help="path to the image to caption",
        )
        return parser.parse_args()


if __name__ == "__main__":
    model = Caption()
    args = model.get_args()
    image_paths = [args.input_img_paths]
    print(model.predict_step(image_paths))
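
    # A minimal sketch of the in-memory path, assuming a local file named
    # "frame.jpg" (hypothetical) exists; any JPEG/PNG bytes would work.
    #
    # with open("frame.jpg", "rb") as f:
    #     buffer = io.BytesIO(f.read())
    # print(model.predict_from_memory([buffer]))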