arjunanand13 commited on
Commit
dae3891
·
verified ·
1 Parent(s): 4303813

Update image_caption.py

Browse files
Files changed (1) hide show
  1. image_caption.py +28 -2
image_caption.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
5
  import torch
6
  from PIL import Image
7
-
8
 
9
  class Caption:
10
  def __init__(self):
@@ -19,7 +19,7 @@ class Caption:
19
  )
20
 
21
  # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
- self.device = torch.device("cpu")
23
  self.model.to(self.device)
24
  self.max_length = 16
25
  self.num_beams = 4
@@ -45,6 +45,32 @@ class Caption:
45
  preds = [pred.strip() for pred in preds]
46
  return preds
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  def get_args(self):
49
  parser = argparse.ArgumentParser()
50
  parser.add_argument( "-i",
 
4
  from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
5
  import torch
6
  from PIL import Image
7
+ import io
8
 
9
  class Caption:
10
  def __init__(self):
 
19
  )
20
 
21
  # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
  self.model.to(self.device)
24
  self.max_length = 16
25
  self.num_beams = 4
 
45
  preds = [pred.strip() for pred in preds]
46
  return preds
47
 
48
+ def predict_from_memory(self, image_buffers):
49
+ images = []
50
+
51
+ for image_buffer in image_buffers:
52
+ # Ensure the buffer is positioned at the start
53
+ if isinstance(image_buffer, io.BytesIO):
54
+ image_buffer.seek(0)
55
+ try:
56
+ i_image = Image.open(image_buffer)
57
+ if i_image.mode != "RGB":
58
+ i_image = i_image.convert("RGB")
59
+ images.append(i_image)
60
+ except Exception as e:
61
+ print(f"Failed to process image buffer: {str(e)}")
62
+ continue
63
+
64
+ return self.process_images(images)
65
+
66
+ def process_images(self, images):
67
+ pixel_values = self.feature_extractor(images=images, return_tensors="pt").pixel_values
68
+ pixel_values = pixel_values.to(self.device)
69
+ output_ids = self.model.generate(pixel_values, **self.gen_kwargs)
70
+ preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
71
+ preds = [pred.strip() for pred in preds]
72
+ return preds
73
+
74
  def get_args(self):
75
  parser = argparse.ArgumentParser()
76
  parser.add_argument( "-i",