|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from PIL import Image |
|
from io import BytesIO |
|
import base64 |
|
import numpy as np |
|
import os |
|
|
|
import torch |
|
from transformers import StoppingCriteria |
|
from .constants import IMAGE_TOKEN_INDEX |
|
|
|
import tempfile |
|
from io import BytesIO |
|
|
|
|
|
def get_frame_from_vcap(vidcap, num_frames=10, fps=None, frame_count=None): |
|
import cv2 |
|
|
|
if fps == None or frame_count == None: |
|
|
|
fps = vidcap.get(cv2.CAP_PROP_FPS) |
|
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) |
|
if fps == 0 or frame_count == 0: |
|
print("Video file not found. return empty images.") |
|
return [ |
|
Image.new("RGB", (720, 720)), |
|
] * num_frames |
|
|
|
duration = frame_count / fps |
|
frame_interval = frame_count // num_frames |
|
if frame_interval == 0 and frame_count <= 1: |
|
print("frame_interval is equal to 0. return empty image.") |
|
return [ |
|
Image.new("RGB", (720, 720)), |
|
] * num_frames |
|
|
|
|
|
images = [] |
|
count = 0 |
|
success = True |
|
frame_indices = np.linspace(0, frame_count - 2, num_frames, dtype=int) |
|
|
|
while success: |
|
|
|
if frame_count >= num_frames: |
|
success, frame = vidcap.read() |
|
if count in frame_indices: |
|
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) |
|
im_pil = Image.fromarray(img) |
|
images.append(im_pil) |
|
if len(images) >= num_frames: |
|
return images |
|
count += 1 |
|
else: |
|
|
|
success, frame = vidcap.read() |
|
if success: |
|
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) |
|
im_pil = Image.fromarray(img) |
|
images.append(im_pil) |
|
count += 1 |
|
elif count >= 1: |
|
width, height = images[-1].size |
|
images = [Image.new("RGB", (width, height))] * \ |
|
(num_frames - len(images)) + images |
|
print("padding frames:", (num_frames - len(images))) |
|
return images |
|
else: |
|
break |
|
raise ValueError( |
|
"Did not find enough frames in the video. return empty image.") |
|
|
|
|
|
def opencv_extract_frames(vpath_or_bytesio, frames=6, fps=None, frame_count=None): |
|
""" |
|
Extract frames from a video using OpenCV. |
|
|
|
Args: |
|
vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video. |
|
frames (int): Number of frames to extract from the video. |
|
|
|
Returns: |
|
list: List of PIL Images extracted from the video. |
|
|
|
Raises: |
|
NotImplementedError: If the type of `vpath_or_bytesio` is not supported. |
|
""" |
|
import cv2 |
|
|
|
if isinstance(vpath_or_bytesio, str): |
|
vidcap = cv2.VideoCapture(vpath_or_bytesio) |
|
return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count) |
|
elif isinstance(vpath_or_bytesio, (BytesIO,)): |
|
|
|
with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video: |
|
temp_video.write(vpath_or_bytesio.read()) |
|
temp_video_name = temp_video.name |
|
vidcap = cv2.VideoCapture(temp_video_name) |
|
return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count) |
|
else: |
|
raise NotImplementedError(type(vpath_or_bytesio)) |
|
|
|
|
|
def load_image_from_base64(image): |
|
return Image.open(BytesIO(base64.b64decode(image))) |
|
|
|
|
|
def expand2square(pil_img, background_color): |
|
""" |
|
Expand the given PIL image to a square shape by adding padding. |
|
|
|
Parameters: |
|
- pil_img: The PIL image to be expanded. |
|
- background_color: The color of the padding to be added. |
|
|
|
Returns: |
|
- The expanded PIL image. |
|
|
|
If the image is already square, it is returned as is. |
|
If the image is wider than it is tall, padding is added to the top and bottom. |
|
If the image is taller than it is wide, padding is added to the left and right. |
|
""" |
|
width, height = pil_img.size |
|
if pil_img.mode == 'L': |
|
background_color = background_color[0] |
|
if width == height: |
|
return pil_img |
|
elif width > height: |
|
result = Image.new(pil_img.mode, (width, width), background_color) |
|
result.paste(pil_img, (0, (width - height) // 2)) |
|
return result |
|
else: |
|
result = Image.new(pil_img.mode, (height, height), background_color) |
|
result.paste(pil_img, ((height - width) // 2, 0)) |
|
return result |
|
|
|
|
|
def process_image(image_file, data_args, image_folder, pil_preprocess_fn=None): |
|
processor = data_args.image_processor |
|
if isinstance(image_file, str): |
|
if image_folder is not None: |
|
image = Image.open(os.path.join( |
|
image_folder, image_file)).convert("RGB") |
|
else: |
|
image = Image.open(image_file).convert("RGB") |
|
else: |
|
|
|
image = image_file.convert("RGB") |
|
|
|
info = None |
|
|
|
if pil_preprocess_fn is not None: |
|
image = pil_preprocess_fn(image) |
|
if isinstance(image, tuple): |
|
image, info = image |
|
|
|
if data_args.image_aspect_ratio == "resize": |
|
if hasattr(data_args.image_processor, "crop_size"): |
|
|
|
crop_size = data_args.image_processor.crop_size |
|
else: |
|
|
|
assert hasattr(data_args.image_processor, "size") |
|
crop_size = data_args.image_processor.size |
|
image = image.resize((crop_size["height"], crop_size["width"])) |
|
if data_args.image_aspect_ratio == "pad": |
|
|
|
def expand2square(pil_img, background_color): |
|
width, height = pil_img.size |
|
if width == height: |
|
return pil_img |
|
elif width > height: |
|
result = Image.new( |
|
pil_img.mode, (width, width), background_color) |
|
result.paste(pil_img, (0, (width - height) // 2)) |
|
return result |
|
else: |
|
result = Image.new( |
|
pil_img.mode, (height, height), background_color) |
|
result.paste(pil_img, ((height - width) // 2, 0)) |
|
return result |
|
|
|
image = expand2square(image, tuple(int(x * 255) |
|
for x in processor.image_mean)) |
|
image = processor.preprocess(image, return_tensors="pt")[ |
|
"pixel_values"][0] |
|
else: |
|
|
|
|
|
|
|
|
|
|
|
image = processor.preprocess(image, return_tensors="pt")[ |
|
"pixel_values"][0] |
|
if info is not None: |
|
return image, info |
|
return image |
|
|
|
|
|
def process_images(images, image_processor, model_cfg): |
|
|
|
model_cfg.image_processor = image_processor |
|
new_images = [process_image(image, model_cfg, None) for image in images] |
|
|
|
if all(x.shape == new_images[0].shape for x in new_images): |
|
new_images = torch.stack(new_images, dim=0) |
|
return new_images |
|
|
|
|
|
|
|
def tokenizer_image_token( |
|
prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None |
|
): |
|
prompt_chunks = [ |
|
tokenizer(chunk).input_ids for chunk in prompt.split("<image>")] |
|
|
|
def insert_separator(X, sep): |
|
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] |
|
|
|
input_ids = [] |
|
offset = 0 |
|
if ( |
|
len(prompt_chunks) > 0 |
|
and len(prompt_chunks[0]) > 0 |
|
and prompt_chunks[0][0] == tokenizer.bos_token_id |
|
): |
|
offset = 1 |
|
input_ids.append(prompt_chunks[0][0]) |
|
|
|
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): |
|
input_ids.extend(x[offset:]) |
|
|
|
if return_tensors is not None: |
|
if return_tensors == "pt": |
|
return torch.tensor(input_ids, dtype=torch.long) |
|
raise ValueError(f"Unsupported tensor type: {return_tensors}") |
|
return input_ids |
|
|
|
|
|
def is_gemma_tokenizer(tokenizer): |
|
return "gemma" in tokenizer.__class__.__name__.lower() |
|
|
|
|
|
def get_model_name_from_path(model_path): |
|
model_path = model_path.strip("/") |
|
model_paths = model_path.split("/") |
|
if model_paths[-1].startswith("checkpoint-"): |
|
return model_paths[-2] + "_" + model_paths[-1] |
|
else: |
|
return model_paths[-1] |
|
|
|
|
|
class KeywordsStoppingCriteria(StoppingCriteria): |
|
def __init__(self, keywords, tokenizer, input_ids): |
|
self.keywords = keywords |
|
self.keyword_ids = [] |
|
self.max_keyword_len = 0 |
|
for keyword in keywords: |
|
cur_keyword_ids = tokenizer(keyword).input_ids |
|
if ( |
|
len(cur_keyword_ids) > 1 |
|
and cur_keyword_ids[0] == tokenizer.bos_token_id |
|
): |
|
cur_keyword_ids = cur_keyword_ids[1:] |
|
if len(cur_keyword_ids) > self.max_keyword_len: |
|
self.max_keyword_len = len(cur_keyword_ids) |
|
self.keyword_ids.append(torch.tensor(cur_keyword_ids)) |
|
self.tokenizer = tokenizer |
|
self.start_len = input_ids.shape[1] |
|
|
|
def call_for_batch( |
|
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs |
|
) -> bool: |
|
offset = min(output_ids.shape[1] - |
|
self.start_len, self.max_keyword_len) |
|
self.keyword_ids = [ |
|
keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids |
|
] |
|
for keyword_id in self.keyword_ids: |
|
if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): |
|
return True |
|
outputs = self.tokenizer.batch_decode( |
|
output_ids[:, -offset:], skip_special_tokens=True |
|
)[0] |
|
for keyword in self.keywords: |
|
if keyword in outputs: |
|
return True |
|
return False |
|
|
|
def __call__( |
|
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs |
|
) -> bool: |
|
outputs = [] |
|
for i in range(output_ids.shape[0]): |
|
outputs.append(self.call_for_batch( |
|
output_ids[i].unsqueeze(0), scores)) |
|
return all(outputs) |
|
|