import os
import copy
import math
import warnings
import shutil
from functools import partial

import torch
import numpy as np

from .model import load_pretrained_model
from .mm_utils import load_images, process_images, load_video, process_video, tokenizer_multimodal_token, get_model_name_from_path, KeywordsStoppingCriteria, DirectResize, sam_preprocess_batch
from .constants import NUM_FRAMES, DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN, MODAL_INDEX_MAP, STREAM_START_TOKEN, STREAM_END_TOKEN
from .model.rynnec_qwen2 import Videollama3Qwen2Processor


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def model_init(model_path=None, min_visual_tokens=None, max_visual_tokens=None, **kwargs):
    model_path = "Alibaba-DAMO-Academy/RynnEC-2B" if model_path is None else model_path
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, **kwargs)

    if max_visual_tokens is not None:
        image_processor.max_tokens = max_visual_tokens
    if min_visual_tokens is not None:
        image_processor.min_tokens = min_visual_tokens

    if tokenizer.pad_token is None and tokenizer.unk_token is not None:
        tokenizer.pad_token = tokenizer.unk_token

    processor = Videollama3Qwen2Processor(image_processor, tokenizer)
    return model, processor
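

# Illustrative usage of `model_init` (a minimal sketch, not part of this module's
# behaviour): the checkpoint name is the default used above; downloading the
# weights and a CUDA device are assumed.
#
#     disable_torch_init()
#     model, processor = model_init("Alibaba-DAMO-Academy/RynnEC-2B")
#     # `processor` bundles the image processor and the tokenizer used below.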


def mm_infer(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='video', **kwargs):
    mask_ids = kwargs.pop('mask_ids', None)
    masks = kwargs.pop('masks', None)

    if modal == 'image':
        modal_token = DEFAULT_IMAGE_TOKEN
        images = images_or_videos
        timestamps = None
    elif modal == 'video':
        modal_token = DEFAULT_VIDEO_TOKEN
        images, timestamps = images_or_videos
    elif modal == 'text':
        modal_token = ''
    else:
        raise ValueError(f"Unsupported modal: {modal}")

    # 1. text preprocess (tag process & generate prompt).
    if isinstance(instruct, str):
        messages = [{'role': 'user', 'content': instruct}]
    elif isinstance(instruct, list):
        messages = copy.deepcopy(instruct)
    else:
        raise ValueError(f"Unsupported type of instruct: {type(instruct)}")

    if all(modal_token not in message["content"] for message in messages):
        warnings.warn("Multimodal token not found in the conversation; adding it automatically at the beginning!")
        messages[0]["content"] = modal_token + messages[0]["content"]

    # Split each message on the modal token and rebuild it as a list of
    # interleaved text / visual content chunks.
    converted_messages = []
    for message in messages:
        chunks = message["content"].split(modal_token)
        converted_messages.append({
            "role": "user",
            "content": []
        })

        for chunk_idx in range(1, 2 * len(chunks)):
            if chunk_idx % 2 == 1:
                chunk = chunks[chunk_idx // 2].strip()
                if chunk:
                    converted_messages[-1]["content"].append({"type": "text", "text": chunk})
            else:
                if modal == 'image':
                    converted_messages[-1]["content"].append({"type": "image"})
                elif modal == 'video':
                    converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})

    messages = converted_messages

    system_message = []
    image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)

    # TODO: attention mask?
    messages = system_message + messages
    data_dict = vlprocessor(
        images=images,
        text=messages,
        merge_size=image_downsampling,
        return_labels=True,
        return_tensors="pt",
    )

    torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16

    # images = [x.to(torch_dtype).cuda(non_blocking=True) for x in data_dict["images"]]
    # grid_thws = [x.cuda(non_blocking=True) for x in data_dict["grid_thws"]]

    # 3. generate response according to visual signals and prompts.
    keywords = [tokenizer.eos_token]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"].unsqueeze(0))

    do_sample = kwargs.get('do_sample', False)
    temperature = kwargs.get('temperature', 0.2 if do_sample else 1.0)
    top_p = kwargs.get('top_p', 0.9 if do_sample else 1.0)
    top_k = kwargs.get('top_k', 20 if do_sample else 50)
    max_new_tokens = kwargs.get('max_new_tokens', 2048)

    data_dict["modals"] = [modal]
    data_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()}
if "pixel_values" in data_dict: | |
data_dict["modals"] = data_dict["modals"] * len(data_dict["grid_sizes"]) | |
data_dict["pixel_values"] = data_dict["pixel_values"].to(torch.bfloat16) | |
with torch.inference_mode(): | |
output_ids = model.generate( | |
input_ids=data_dict["input_ids"].unsqueeze(0).cuda(), | |
pixel_values=data_dict["pixel_values"], | |
grid_sizes=data_dict["grid_sizes"], | |
merge_sizes=data_dict["merge_sizes"], | |
modals=data_dict["modals"], | |
do_sample=do_sample, | |
temperature=temperature, | |
max_new_tokens=max_new_tokens, | |
top_p=top_p, | |
top_k=top_k, | |
use_cache=True, | |
stopping_criteria=[stopping_criteria], | |
pad_token_id=tokenizer.eos_token_id, | |
masks=[masks], | |
mask_ids=mask_ids | |
) | |
outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True) | |
return outputs | |
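

# Illustrative call to `mm_infer` (a hedged sketch; the image file and question
# are placeholders, and a CUDA device with loaded weights is assumed). The
# tokenizer is assumed to be reachable as `processor.tokenizer`; the optional
# `masks` / `mask_ids` kwargs carry region masks that are forwarded to
# `model.generate` for object-level questions.
#
#     from PIL import Image
#     image = Image.open("example.jpg").convert("RGB")
#     question = DEFAULT_IMAGE_TOKEN + "\nDescribe the object in the scene."
#     answer = mm_infer([image], processor, question, model, processor.tokenizer,
#                       modal="image")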


def mm_infer_segmentation(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='video', seg_start_idx=0, **kwargs):
    image2maskids = kwargs.get('image2maskids', [])

    img_size = 1024
    sam_transform = DirectResize(img_size)

    if modal == 'image':
        modal_token = DEFAULT_IMAGE_TOKEN
        images = images_or_videos
        timestamps = None
    elif modal == 'video':
        modal_token = DEFAULT_VIDEO_TOKEN
        images, timestamps = images_or_videos
    elif modal == 'text':
        modal_token = ''
    else:
        raise ValueError(f"Unsupported modal: {modal}")

    # Resize each frame for the SAM branch and stack the results into one batch.
    sam_images = []
    sam_size = None
    if len(images) > 0:
        for image in images:
            sam_image = sam_transform.apply_image(np.array(image))
            sam_images.append(sam_image)
            if sam_size is None:
                sam_size = sam_image.shape[:2]
        sam_images = np.array(sam_images)
        sam_images = torch.from_numpy(sam_images).permute(0, 3, 1, 2).contiguous()
        sam_images = sam_preprocess_batch(sam_images)

    # 1. text preprocess (tag process & generate prompt).
    if isinstance(instruct, str):
        messages = [{'role': 'user', 'content': instruct}]
    elif isinstance(instruct, list):
        messages = copy.deepcopy(instruct)
    else:
        raise ValueError(f"Unsupported type of instruct: {type(instruct)}")

    if all(modal_token not in message["content"] for message in messages):
        warnings.warn("Multimodal token not found in the conversation; adding it automatically at the beginning!")
        messages[0]["content"] = modal_token + messages[0]["content"]

    # Split each message on the modal token and rebuild it as a list of
    # interleaved text / visual content chunks.
    converted_messages = []
    for message in messages:
        chunks = message["content"].split(modal_token)
        converted_messages.append({
            "role": "user",
            "content": []
        })

        for chunk_idx in range(1, 2 * len(chunks)):
            if chunk_idx % 2 == 1:
                chunk = chunks[chunk_idx // 2].strip()
                if chunk:
                    converted_messages[-1]["content"].append({"type": "text", "text": chunk})
            else:
                if modal == 'image':
                    converted_messages[-1]["content"].append({"type": "image"})
                elif modal == 'video':
                    converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})

    messages = converted_messages

    system_message = []
    image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)

    # TODO: attention mask?
    messages = system_message + messages
    data_dict = vlprocessor(
        images=images,
        text=messages,
        merge_size=image_downsampling,
        return_labels=True,
        return_tensors="pt",
    )

    torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16

    keywords = [tokenizer.eos_token]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"].unsqueeze(0))

    do_sample = kwargs.get('do_sample', False)
    temperature = kwargs.get('temperature', 0.2 if do_sample else 1.0)
    top_p = kwargs.get('top_p', 0.9 if do_sample else 1.0)
    top_k = kwargs.get('top_k', 20 if do_sample else 50)
    max_new_tokens = kwargs.get('max_new_tokens', 2048)

    data_dict["modals"] = [modal]
    data_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()}

    if "pixel_values" in data_dict:
        data_dict["modals"] = data_dict["modals"] * len(data_dict["grid_sizes"])
        data_dict["pixel_values"] = data_dict["pixel_values"].to(torch.bfloat16)

    with torch.inference_mode():
        output_ids, pred_masks = model.inference(
            input_ids=data_dict["input_ids"].unsqueeze(0).cuda(),
            pixel_values=data_dict["pixel_values"],
            grid_sizes=data_dict["grid_sizes"],
            merge_sizes=data_dict["merge_sizes"],
            modals=data_dict["modals"],
            sam_images=[sam_images],
            sam_size=[sam_size],
            image2maskids=[image2maskids],
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.eos_token_id,
            seg_start_idx=seg_start_idx,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    # Threshold the predicted mask logits at 0.5 after the sigmoid.
    pred_masks_sigmoid = pred_masks.sigmoid() > 0.5

    return outputs, pred_masks_sigmoid
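

if __name__ == "__main__":
    # Minimal end-to-end sketch (illustrative only, not part of the library):
    # segment one object in a single image. A CUDA device and the RynnEC-2B
    # weights are assumed; "example.jpg" and the question are placeholders, and
    # the tokenizer is assumed to be reachable as `processor.tokenizer`.
    from PIL import Image

    disable_torch_init()
    model, processor = model_init("Alibaba-DAMO-Academy/RynnEC-2B")

    image = Image.open("example.jpg").convert("RGB")
    question = DEFAULT_IMAGE_TOKEN + "\nPlease segment the most salient object."
    text, masks = mm_infer_segmentation(
        [image], processor, question, model, processor.tokenizer, modal="image",
    )
    # `masks` holds the sigmoid-thresholded binary segmentation masks.
    print(text)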