# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

import glob
import math
import os
import tempfile
from collections import defaultdict
from io import BytesIO
from typing import Any, Dict, List, Optional, Union

import cv2
import kaldiio
import numpy as np
import PIL
import PIL.Image
import requests
import soundfile as sf
import torch
import whisper
from librosa import resample as librosa_resample
from pydub import AudioSegment
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler
from transformers import AutoFeatureExtractor, PretrainedConfig

from llava.constants import MEDIA_TOKENS
from llava.media import Image, Sound, Speech, Video
from llava.utils import make_list
from llava.utils.logging import logger

# wav_processor = AutoFeatureExtractor.from_pretrained('pretrained_models/AF-Whisper')
wav_processor = AutoFeatureExtractor.from_pretrained("Qwen/Qwen2-Audio-7B")

__all__ = ["extract_media"]

def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    x = np.clip(x, a_min=-1.0, a_max=1.0)
    return (x * 32767.0).astype(np.int16)
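
# Hedged note: the two helpers above form the quantization round trip applied in
# _load_sound_mask below; the waveform is clipped to [-1, 1], quantized to int16,
# and mapped back to float32, e.g.
#   int16_to_float32(float32_to_int16(np.array([0.5, -1.2])))
#   -> array([ 0.4999847, -1.0], dtype=float32)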

def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
    if isinstance(image, Image):
        if image.path.startswith("http://") or image.path.startswith("https://"):
            image = PIL.Image.open(requests.get(image.path, stream=True).raw)
        else:
            image = PIL.Image.open(image.path)
    return image

def _load_video_bytesio(video_bytesio: BytesIO, *, num_frames: int) -> List[PIL.Image.Image]:
    with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
        temp_video.write(video_bytesio.read())
        temp_video.flush()  # ensure the bytes are on disk before cv2 opens the file by name
        temp_video_name = temp_video.name
        return _load_video(temp_video_name, num_frames=num_frames)

def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
    # Load video frames from a directory
    if os.path.isdir(video_path):
        frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
        indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
        return [PIL.Image.open(frame_paths[index]) for index in indices]

    # Load video frames from a video file
    vidcap = cv2.VideoCapture(video_path)

    # Find the last frame as frame count might not be accurate
    frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    while frame_count > 0:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
        if vidcap.grab():
            break
        frame_count -= 1
    else:
        raise ValueError(f"Video '{video_path}' has no frames.")

    # Extract frames uniformly
    indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
    frames = {}
    for index in indices:
        if index in frames:
            continue
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
        success, frame = vidcap.read()
        if not success:
            logger.warning(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames[index] = PIL.Image.fromarray(frame)
    return [frames[index] for index in indices if index in frames]

def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]:
    num_frames = config.num_video_frames
    if getattr(config, "fps", 0) != 0:
        logger.warning("Extracting frames from video with specified FPS is not supported yet. Ignored.")
    if isinstance(video.path, BytesIO):
        frames = _load_video_bytesio(video.path, num_frames=num_frames)
    else:
        frames = _load_video(video.path, num_frames=num_frames)
    return frames

def _load_speech(speech_path: str):
    # Load the speech audio and convert it to a Whisper log-mel spectrogram
    if speech_path is None:
        return None

    speech_outputs = []
    speech = whisper.load_audio(speech_path)
    speech = whisper.pad_or_trim(speech)
    mel = whisper.log_mel_spectrogram(speech)
    speech_outputs.append(mel.unsqueeze(0))
    speech_frames = torch.stack(speech_outputs, dim=0)
    return speech_frames.numpy().tolist()
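
# Hedged note: with whisper's defaults (80 mel bins, 30 s pad/trim) the mel above is
# (80, 3000), so the nested list returned by _load_speech has shape (1, 1, 80, 3000).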

def _extract_speech(speech: Speech, config: PretrainedConfig):
    frames = _load_speech(speech.path)
    return frames

def _get_num_windows(T, sr):
    window_length = int(30.0 * sr)
    window_overlap = int(0.0 * sr)
    max_num_window = 20
    num_windows = 1
    if T <= window_length:
        num_windows = 1
        full_length = window_length
    elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap):
        num_windows = max_num_window
        full_length = max_num_window * window_length - (max_num_window - 1) * window_overlap
    else:
        num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap)))
        full_length = num_windows * window_length - (num_windows - 1) * window_overlap
    return num_windows, full_length
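
# Worked example with the constants hard-coded above: at sr = 16000 a window is
# 480000 samples (30 s) with zero overlap, so a 75 s clip (T = 1200000) gives
# num_windows = 1 + ceil((1200000 - 480000) / 480000) = 3 and full_length = 1440000 (90 s).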

def _load_audio(file_path, target_sr=16000, duration=30.0, start=0.0):
    if file_path.endswith(".mp3"):
        # Decode MP3 with pydub, then crop, resample, and downmix to mono
        audio = AudioSegment.from_file(file_path)
        if len(audio) > (start + duration) * 1000:
            audio = audio[start * 1000:(start + duration) * 1000]
        if audio.frame_rate != target_sr:
            audio = audio.set_frame_rate(target_sr)
        if audio.channels > 1:
            audio = audio.set_channels(1)
        data = np.array(audio.get_array_of_samples())
        if audio.sample_width == 2:
            data = data.astype(np.float32) / np.iinfo(np.int16).max
        elif audio.sample_width == 4:
            data = data.astype(np.float32) / np.iinfo(np.int32).max
        else:
            raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
    else:
        # Read other formats with soundfile; resample with librosa if needed
        with sf.SoundFile(file_path) as audio:
            original_sr = audio.samplerate
            channels = audio.channels
            max_frames = int((start + duration) * original_sr)
            audio.seek(int(start * original_sr))
            frames_to_read = min(max_frames, len(audio))
            data = audio.read(frames_to_read)
            if data.max() > 1 or data.min() < -1:
                data = data / max(abs(data.max()), abs(data.min()))
        if original_sr != target_sr:
            if channels == 1:
                data = librosa_resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
            else:
                data = librosa_resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
        else:
            if channels != 1:
                data = data.T[0]

    # Normalize to [-1, 1]
    if data.min() >= 0:
        data = 2 * data / abs(data.max()) - 1.0
    else:
        data = data / max(abs(data.max()), abs(data.min()))
    assert len(data.shape) == 1, data.shape
    return data

def _load_sound_mask(sound_file, sample_rate=16000, window_length=30.0, window_overlap=0.0, max_num_window=20, audio_start=0.0):
    if sound_file is None:
        return None

    window_length = int(window_length * sample_rate)
    window_overlap = int(window_overlap * sample_rate)
    max_num_window = int(max_num_window)
    duration = max_num_window * (window_length - window_overlap) + window_overlap

    sound_outputs = []
    audio_feature_masks = []
    audio_embed_masks = []
    try:
        audio_data = _load_audio(sound_file, sample_rate, duration, audio_start)  # already cuts to max duration
        T = len(audio_data)
        audio_data = audio_data.reshape(1, -1)
        num_windows, full_length = _get_num_windows(T, sample_rate)

        audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float()
        for i in range(num_windows):
            audio_embed_mask = torch.zeros(750)
            start = i * (window_length - window_overlap)
            audio_data_tensor_this = audio_data_tensor[:, start:start + window_length]
            orig_length = audio_data_tensor_this.shape[1]
            audio_data_tensor_this = wav_processor(audio_data_tensor_this.cpu().numpy(), sampling_rate=sample_rate, return_tensors="pt")
            sound_outputs.append(audio_data_tensor_this["input_features"])
            # calculate the mask for the input melspec to Whisper
            melspec_frames_this_window = int(math.ceil(orig_length / 160))
            feature_attention_mask = torch.zeros(3000, dtype=torch.int32)
            feature_attention_mask[:melspec_frames_this_window] = 1
            audio_feature_masks.append(feature_attention_mask.unsqueeze(0))
            # calculate the mask for the output embedding for use in AF3
            conv_lengths = (melspec_frames_this_window - 1) // 2 + 1
            output_embedding_lengths = (conv_lengths - 2) // 2 + 1
            audio_embed_mask[:output_embedding_lengths] = 1
            audio_embed_masks.append(audio_embed_mask)
    except Exception as e:
        logger.error(f"Error loading sound file '{sound_file}': {e}")
        sound_outputs.append(torch.zeros(1, 128, 3000))
        audio_feature_masks.append(torch.zeros(1, 3000, dtype=torch.int32))
        audio_embed_masks.append(torch.zeros(750))

    sound_outputs = torch.stack(sound_outputs, dim=0)
    audio_feature_masks = torch.stack(audio_feature_masks, dim=0)
    audio_embed_masks = torch.stack(audio_embed_masks, dim=0)
    return sound_outputs.numpy().tolist(), audio_feature_masks, audio_embed_masks
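
# Hedged note on shapes: each window's "input_features" from the Qwen2-Audio feature
# extractor is (1, 128, 3000), matching the zero fallback in the except branch, so the
# three return values stack to (num_windows, 1, 128, 3000), (num_windows, 1, 3000) and
# (num_windows, 750); the first is converted to a nested Python list before returning.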

def _extract_sound_mask(sound: Sound, config: PretrainedConfig):
    frames, audio_feature_masks, audio_embed_masks = _load_sound_mask(sound.path)
    return frames, audio_feature_masks, audio_embed_masks

def extract_media(
    messages: List[Dict[str, Any]],
    config: Optional[PretrainedConfig] = None,
    draft: bool = False,
) -> Dict[str, List[Any]]:
    media = defaultdict(list)
    media_meta = defaultdict(list)

    for message in messages:
        text = ""
        for part in make_list(message["value"]):
            if isinstance(part, str):
                for token in MEDIA_TOKENS.values():
                    if token in part:
                        logger.warning(f"Media token '{token}' found in text: '{part}'. Removed.")
                        part = part.replace(token, "").strip()
                text += part
            if isinstance(part, (Image, PIL.Image.Image)):
                if draft:
                    media["image"].append(part)
                else:
                    media["image"].append(_extract_image(part))
                text += MEDIA_TOKENS["image"]
            if isinstance(part, Video):
                if draft:
                    media["video"].append(part)
                else:
                    media["video"].append(_extract_video(part, config))
                text += MEDIA_TOKENS["video"]
            if isinstance(part, Speech):
                if draft:
                    media["speech"].append(part)
                else:
                    media["speech"].append(_extract_speech(part, config))
                text += MEDIA_TOKENS["speech"]
            if isinstance(part, Sound):
                if draft:
                    media["sound"].append(part)
                    text += MEDIA_TOKENS["sound"]
                else:
                    sound, audio_feature_masks, audio_embed_masks = _extract_sound_mask(part, config)
                    media["sound"].append(sound)
                    media_meta["sound_feature_masks"].append(audio_feature_masks)
                    media_meta["sound_embed_masks"].append(audio_embed_masks)
                    # one sound token per 30 s window extracted above
                    text += MEDIA_TOKENS["sound"] * len(sound)
        message["value"] = text
    return media, media_meta
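

# Minimal usage sketch (hypothetical audio path and prompt): builds a single-turn message
# around a Sound object and extracts its windowed features; config is only needed for
# video/speech parts, so None suffices here.
if __name__ == "__main__":
    example_messages = [{"from": "human", "value": [Sound("/path/to/audio.wav"), "Describe this audio."]}]
    media, media_meta = extract_media(example_messages, config=None)
    print(len(media["sound"]), len(media_meta["sound_feature_masks"]))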