# rynnec/__init__.py
import os
import copy
import math
import warnings
import shutil
from functools import partial
import torch
import numpy as np
from .model import load_pretrained_model
from .mm_utils import (
    load_images, process_images, load_video, process_video,
    tokenizer_multimodal_token, get_model_name_from_path,
    KeywordsStoppingCriteria, DirectResize, sam_preprocess_batch,
)
from .constants import (
    NUM_FRAMES, DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN,
    MODAL_INDEX_MAP, STREAM_START_TOKEN, STREAM_END_TOKEN,
)
from .model.rynnec_qwen2 import Videollama3Qwen2Processor


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def model_init(model_path=None, min_visual_tokens=None, max_visual_tokens=None, **kwargs):
    model_path = "Alibaba-DAMO-Academy/RynnEC-2B" if model_path is None else model_path
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, **kwargs)

    if max_visual_tokens is not None:
        image_processor.max_tokens = max_visual_tokens
    if min_visual_tokens is not None:
        image_processor.min_tokens = min_visual_tokens

    if tokenizer.pad_token is None and tokenizer.unk_token is not None:
        tokenizer.pad_token = tokenizer.unk_token

    processor = Videollama3Qwen2Processor(image_processor, tokenizer)

    return model, processor
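

# Usage sketch (illustrative, not part of the original file): load the default
# checkpoint. Assumes a CUDA device and that the "Alibaba-DAMO-Academy/RynnEC-2B"
# weights are reachable on the Hugging Face Hub.
#
#   disable_torch_init()
#   model, processor = model_init("Alibaba-DAMO-Academy/RynnEC-2B")
#   model = model.cuda().eval()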


def mm_infer(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='video', **kwargs):
    mask_ids = kwargs.pop('mask_ids', None)
    masks = kwargs.pop('masks', None)

    if modal == 'image':
        modal_token = DEFAULT_IMAGE_TOKEN
        images = images_or_videos
        timestamps = None
    elif modal == 'video':
        modal_token = DEFAULT_VIDEO_TOKEN
        images, timestamps = images_or_videos
    elif modal == 'text':
        modal_token = ''
        # text-only path: no visual inputs (assumes the processor accepts images=None)
        images, timestamps = None, None
    else:
        raise ValueError(f"Unsupported modal: {modal}")
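    # Expected inputs (informal sketch): for modal='image', `images_or_videos` is a
    # list of frames (e.g. PIL images); for modal='video', it is a (frames,
    # timestamps) tuple such as the return value of load_video() from .mm_utils.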
    # 1. text preprocess (tag process & generate prompt).
    if isinstance(instruct, str):
        messages = [{'role': 'user', 'content': instruct}]
    elif isinstance(instruct, list):
        messages = copy.deepcopy(instruct)
    else:
        raise ValueError(f"Unsupported type of instruct: {type(instruct)}")

    if all(modal_token not in message["content"] for message in messages):
        warnings.warn("Modal tag not found in the conversation; adding it automatically at the beginning!")
        messages[0]["content"] = modal_token + messages[0]["content"]
    converted_messages = []
    for message in messages:
        chunks = message["content"].split(modal_token)
        converted_messages.append({
            # keep the original role so multi-turn conversations are not
            # flattened into all-user messages
            "role": message["role"],
            "content": []
        })

        # Interleave text chunks with visual placeholders: after the split,
        # odd chunk_idx values are text and even values are modal slots.
        for chunk_idx in range(1, 2 * len(chunks)):
            if chunk_idx % 2 == 1:
                chunk = chunks[chunk_idx // 2].strip()
                if chunk:
                    converted_messages[-1]["content"].append({"type": "text", "text": chunk})
            else:
                if modal == 'image':
                    converted_messages[-1]["content"].append({"type": "image"})
                elif modal == 'video':
                    converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})

    messages = converted_messages
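    # For example, with modal='video', a prompt "<video>\nDescribe the scene."
    # becomes:
    #   [{"role": "user", "content": [
    #       {"type": "video", "num_frames": len(images), "time": timestamps},
    #       {"type": "text", "text": "Describe the scene."}]}]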
    system_message = []

    image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)

    # 2. vision preprocess (pack frames and tokenize the interleaved messages).
    # TODO: attention mask?
    messages = system_message + messages

    data_dict = vlprocessor(
        images=images,
        text=messages,
        merge_size=image_downsampling,
        return_labels=True,
        return_tensors="pt",
    )
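    # `data_dict` is expected to carry at least "input_ids", "pixel_values",
    # "grid_sizes" and "merge_sizes" (see the generate() call below); every
    # tensor in it is moved to the GPU before generation.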
    torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16

    # images = [x.to(torch_dtype).cuda(non_blocking=True) for x in data_dict["images"]]
    # grid_thws = [x.cuda(non_blocking=True) for x in data_dict["grid_thws"]]

    # 3. generate response according to visual signals and prompts.
    keywords = [tokenizer.eos_token]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"].unsqueeze(0))

    do_sample = kwargs.get('do_sample', False)
    temperature = kwargs.get('temperature', 0.2 if do_sample else 1.0)
    top_p = kwargs.get('top_p', 0.9 if do_sample else 1.0)
    top_k = kwargs.get('top_k', 20 if do_sample else 50)
    max_new_tokens = kwargs.get('max_new_tokens', 2048)

    data_dict["modals"] = [modal]
    data_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()}
    if "pixel_values" in data_dict:
        data_dict["modals"] = data_dict["modals"] * len(data_dict["grid_sizes"])
        data_dict["pixel_values"] = data_dict["pixel_values"].to(torch.bfloat16)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=data_dict["input_ids"].unsqueeze(0).cuda(),
            # .get() keeps the text-only path from raising a KeyError when no
            # visual tensors were produced (the model is assumed to accept None).
            pixel_values=data_dict.get("pixel_values"),
            grid_sizes=data_dict.get("grid_sizes"),
            merge_sizes=data_dict.get("merge_sizes"),
            modals=data_dict["modals"],
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.eos_token_id,
            masks=[masks],
            mask_ids=mask_ids,
        )

    outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return outputs
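

# Usage sketch for mm_infer (illustrative): "demo.mp4" is a hypothetical path,
# load_video() (re-exported from .mm_utils above) is assumed to return a
# (frames, timestamps) tuple, and `processor.tokenizer` is assumed to expose the
# tokenizer wrapped by Videollama3Qwen2Processor.
#
#   frames, timestamps = load_video("demo.mp4")
#   answer = mm_infer((frames, timestamps), processor, "Describe the video.",
#                     model, processor.tokenizer, modal="video")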


def mm_infer_segmentation(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='video', seg_start_idx=0, **kwargs):
    image2maskids = kwargs.get('image2maskids', [])
    img_size = 1024
    sam_transform = DirectResize(img_size)

    if modal == 'image':
        modal_token = DEFAULT_IMAGE_TOKEN
        images = images_or_videos
        timestamps = None
    elif modal == 'video':
        modal_token = DEFAULT_VIDEO_TOKEN
        images, timestamps = images_or_videos
    elif modal == 'text':
        modal_token = ''
        # text-only path: keep `images` iterable so the SAM loop below is skipped
        images, timestamps = [], None
    else:
        raise ValueError(f"Unsupported modal: {modal}")
    sam_images = []
    sam_size = None
    if len(images) > 0:
        for image in images:
            sam_image = sam_transform.apply_image(np.array(image))
            sam_images.append(sam_image)
            if sam_size is None:
                sam_size = sam_image.shape[:2]
        sam_images = np.array(sam_images)
        sam_images = torch.from_numpy(sam_images).permute(0, 3, 1, 2).contiguous()
        sam_images = sam_preprocess_batch(sam_images)
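    # At this point `sam_images` is an (N, 3, H, W) tensor of frames resized for
    # the SAM branch (img_size=1024) and run through sam_preprocess_batch (from
    # .mm_utils, which presumably applies SAM-style normalization/padding);
    # `sam_size` keeps the post-resize (H, W) used to map masks back to frames.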
    # 1. text preprocess (tag process & generate prompt).
    if isinstance(instruct, str):
        messages = [{'role': 'user', 'content': instruct}]
    elif isinstance(instruct, list):
        messages = copy.deepcopy(instruct)
    else:
        raise ValueError(f"Unsupported type of instruct: {type(instruct)}")

    if all(modal_token not in message["content"] for message in messages):
        warnings.warn("Modal tag not found in the conversation; adding it automatically at the beginning!")
        messages[0]["content"] = modal_token + messages[0]["content"]
    converted_messages = []
    for message in messages:
        chunks = message["content"].split(modal_token)
        converted_messages.append({
            # keep the original role so multi-turn conversations are not
            # flattened into all-user messages
            "role": message["role"],
            "content": []
        })

        # Interleave text chunks with visual placeholders: after the split,
        # odd chunk_idx values are text and even values are modal slots.
        for chunk_idx in range(1, 2 * len(chunks)):
            if chunk_idx % 2 == 1:
                chunk = chunks[chunk_idx // 2].strip()
                if chunk:
                    converted_messages[-1]["content"].append({"type": "text", "text": chunk})
            else:
                if modal == 'image':
                    converted_messages[-1]["content"].append({"type": "image"})
                elif modal == 'video':
                    converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})

    messages = converted_messages
    system_message = []

    image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)

    # 2. vision preprocess (pack frames and tokenize the interleaved messages).
    # TODO: attention mask?
    messages = system_message + messages

    data_dict = vlprocessor(
        images=images,
        text=messages,
        merge_size=image_downsampling,
        return_labels=True,
        return_tensors="pt",
    )
    torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16

    keywords = [tokenizer.eos_token]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"].unsqueeze(0))

    do_sample = kwargs.get('do_sample', False)
    temperature = kwargs.get('temperature', 0.2 if do_sample else 1.0)
    top_p = kwargs.get('top_p', 0.9 if do_sample else 1.0)
    top_k = kwargs.get('top_k', 20 if do_sample else 50)
    max_new_tokens = kwargs.get('max_new_tokens', 2048)

    data_dict["modals"] = [modal]
    data_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()}
    if "pixel_values" in data_dict:
        data_dict["modals"] = data_dict["modals"] * len(data_dict["grid_sizes"])
        data_dict["pixel_values"] = data_dict["pixel_values"].to(torch.bfloat16)
    with torch.inference_mode():
        output_ids, pred_masks = model.inference(
            input_ids=data_dict["input_ids"].unsqueeze(0).cuda(),
            pixel_values=data_dict["pixel_values"],
            grid_sizes=data_dict["grid_sizes"],
            merge_sizes=data_dict["merge_sizes"],
            modals=data_dict["modals"],
            sam_images=[sam_images],
            sam_size=[sam_size],
            image2maskids=[image2maskids],
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.eos_token_id,
            seg_start_idx=seg_start_idx,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    pred_masks_sigmoid = pred_masks.sigmoid() > 0.5

    return outputs, pred_masks_sigmoid
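

# Usage sketch for mm_infer_segmentation (illustrative; "demo.mp4" and the prompt
# are hypothetical, and load_video() is assumed to return (frames, timestamps)):
#
#   frames, timestamps = load_video("demo.mp4")
#   text, masks = mm_infer_segmentation(
#       (frames, timestamps), processor, "Please segment the red cup.",
#       model, processor.tokenizer, modal="video")
#   # `masks` is a boolean tensor (sigmoid > 0.5) aligned with the SAM-resized frames.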