|
|
|
from io import BytesIO |
|
import av |
|
import base64 |
|
from PIL import Image |
|
from typing import List |
|
from dataclasses import dataclass |
|
|
|
def sample(N, K):
    """Pick ``K`` evenly spaced indices from ``range(N)``.

    When sampling happens, the first index (``0``) and the last index
    (``N - 1``) are always included. The remaining picks are spread uniformly
    between them.

    Args:
        N: Number of items to choose from (e.g. a video's frame count).
        K: Desired number of indices. If ``K >= N`` or ``K < 2``, no sampling
            is done and every index in ``range(N)`` is returned.

    Returns:
        A strictly increasing list of ints drawn from ``range(N)``.
    """
    indices = list(range(N))
    count = len(indices)
    if K >= count or K < 2:
        return indices
    last = count - 1
    # floor(i * last / (K - 1)) in exact integer arithmetic.
    # The endpoints 0 and `last` are hit exactly.
    # Because K < count, the spacing last / (K - 1) is >= 1, so the floors are
    # strictly increasing and there are no duplicate indices.
    return [i * last // (K - 1) for i in range(K)]
|
|
|
def grid_sample(array, N, K):
    """Split ``array`` into up to ``K`` contiguous groups and thin each to ``N`` items.

    Group sizes differ by at most one. The first ``len(array) % K`` groups get
    the extra element. If ``K`` exceeds ``len(array)``, the groups that would
    be empty are omitted, so every returned group is non-empty.

    Args:
        array: Sliceable sequence to partition (e.g. a list of frames).
        N: Maximum number of items kept per group. A group with at most ``N``
            items is kept whole. A larger group is reduced to ``N`` items taken
            at evenly spaced positions across it.
        K: Number of groups requested.

    Returns:
        A list of non-empty groups. Each group is either a slice of ``array``
        or a list of its items. Returns ``[]`` when ``K <= 0`` or ``N <= 0``.
    """
    if K <= 0 or N <= 0:
        return []
    group_size, remainder = divmod(len(array), K)
    sampled_groups = []
    for i in range(K):
        start = i * group_size + min(i, remainder)
        end = start + group_size + (1 if i < remainder else 0)
        group = array[start:end]
        size = len(group)
        if size == 0:
            # Only happens when K > len(array); an empty group would break
            # callers that stitch each group into an image.
            continue
        if N >= size:
            sampled_groups.append(group)
        else:
            # j * size // N spreads the picks over the whole group.
            # The previous j * (size // N) stopped short whenever
            # size % N != 0, so picks clustered at the start.
            sampled_groups.append([group[j * size // N] for j in range(N)])
    return sampled_groups
|
|
|
@dataclass
class VideoProcessor:
    """Decode video files into PIL frames and package them for downstream use.

    Attributes:
        frame_format: Pillow format name used by ``to_base64_list`` when
            encoding images (e.g. ``"JPEG"``).
        frame_limit: Target frame count for ``_decode``; forwarded to
            ``sample``, which returns every frame when this is < 2 or >= the
            stream's reported frame count.
        frame_skip: ``decode`` keeps every ``frame_skip``-th frame (values
            < 1 are treated as 1, i.e. keep every frame).
    """

    frame_format: str = "JPEG"
    frame_limit: int = 1
    frame_skip: int = 1

    def _decode(self, video_path: str) -> List[Image.Image]:
        """Seek-decode the frames chosen by ``sample`` from the first video stream.

        Each sampled frame index is converted to a presentation timestamp in
        the stream's ``time_base``. The container does a backward seek to that
        timestamp, then decodes forward to the first frame that reaches it.
        This yields the requested frame rather than whatever frame the seek
        landed on.

        NOTE(review): relies on the stream reporting ``frames`` and
        ``average_rate``. If the container reports a frame count of 0, nothing
        is returned.
        """
        frames = []
        with av.open(video_path) as container:
            src = container.streams.video[0]
            time_base = src.time_base
            framerate = src.average_rate
            # Frame pts values are offset by the stream's start time (which may
            # be non-zero), so the seek target must be offset the same way.
            start_pts = src.start_time or 0
            for index in sample(src.frames, self.frame_limit):
                # One frame period, in pts units (exact Fraction arithmetic).
                frame_period = 1 / (framerate * time_base)
                target = start_pts + index * frame_period
                container.seek(round(target), backward=True, stream=src)
                chosen = None
                for frame in container.decode(video=0):
                    chosen = frame
                    # Accept a frame within half a period of the target, to
                    # absorb rounding in average_rate / time_base.
                    if frame.pts is None or frame.pts >= target - frame_period / 2:
                        break
                if chosen is not None:
                    frames.append(chosen.to_image())
        return frames

    def decode(self, video_path: str) -> List[Image.Image]:
        """Decode every ``frame_skip``-th frame of the first video stream.

        The container is closed when decoding finishes, including on error.
        """
        step = max(1, self.frame_skip)
        with av.open(video_path) as container:
            return [
                frame.to_image()
                for index, frame in enumerate(container.decode(video=0))
                if index % step == 0
            ]

    def concatenate(self, frames: List[Image.Image], direction: str = "horizontal") -> Image.Image:
        """Stitch ``frames`` into a single RGB image.

        Args:
            frames: Images to join, in order.
            direction: ``"horizontal"`` lays frames left to right, top-aligned.
                Any other value lays them top to bottom, left-aligned.

        Returns:
            A new RGB image. Areas not covered by a frame are black.

        Raises:
            ValueError: if ``frames`` is empty.
        """
        if not frames:
            raise ValueError("concatenate() requires at least one frame")
        horizontal = direction == "horizontal"
        widths, heights = zip(*(frame.size for frame in frames))
        if horizontal:
            canvas_size = (sum(widths), max(heights))
        else:
            canvas_size = (max(widths), sum(heights))
        concatenated_image = Image.new('RGB', canvas_size)
        offset = 0
        for frame in frames:
            concatenated_image.paste(frame, (offset, 0) if horizontal else (0, offset))
            offset += frame.width if horizontal else frame.height
        return concatenated_image

    def grid_concatenate(self, frames: List[Image.Image], group_size, limit=10) -> List[Image.Image]:
        """Split ``frames`` into groups and stitch each group horizontally.

        Args:
            frames: Images to group.
            group_size: Maximum number of frames kept per group (``N`` in
                ``grid_sample``).
            limit: Number of groups requested (``K`` in ``grid_sample``).

        Returns:
            One stitched image per non-empty group.
        """
        sampled_groups = grid_sample(frames, group_size, limit)
        # Empty groups (e.g. when limit > len(frames)) would make concatenate fail.
        return [self.concatenate(group) for group in sampled_groups if len(group) > 0]

    def to_base64_list(self, images: List[Image.Image]) -> List[str]:
        """Encode each image in ``frame_format`` and return base64 strings.

        For JPEG output, images in modes JPEG cannot store (e.g. RGBA, P, LA)
        are converted to RGB first. Without this, Pillow raises on save.
        """
        jpeg = self.frame_format.upper() == "JPEG"
        base64_list = []
        for image in images:
            if jpeg and image.mode not in ("1", "L", "RGB", "CMYK"):
                image = image.convert("RGB")
            buffered = BytesIO()
            image.save(buffered, format=self.frame_format)
            base64_list.append(base64.b64encode(buffered.getvalue()).decode('utf-8'))
        return base64_list
|
|