# Spaces: Running on A100
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import torch | |
def extract_local_zigzag(value, rank, world_size, device, dim=1):
    """Return this rank's zigzag shard of ``value`` along ``dim``.

    ``value`` is split into ``2 * world_size`` chunks along ``dim``; the
    shard for ``rank`` is chunk ``rank`` concatenated with its mirror chunk
    ``2 * world_size - rank - 1``, moved to ``device``.

    Args:
        value: Tensor to shard.
        rank: Index of the chunk pair to extract.
        world_size: Number of ranks; the tensor is cut into twice as many chunks.
        device: Target device passed to ``Tensor.to``.
        dim: Dimension along which to split and concatenate.

    Returns:
        The concatenation of the two selected chunks, on ``device``.

    Raises:
        ValueError: If ``value`` cannot be split into exactly
            ``2 * world_size`` chunks along ``dim``.
    """
    num_chunks = 2 * world_size
    value_chunks = value.chunk(num_chunks, dim=dim)
    # ``Tensor.chunk`` may return fewer chunks than requested. Indexing would
    # then either raise a bare IndexError on some ranks or silently pair the
    # wrong chunks on others, so fail consistently on every rank instead.
    if len(value_chunks) != num_chunks:
        raise ValueError(
            f"Cannot split size {value.size(dim)} along dim {dim} into "
            f"{num_chunks} chunks (got {len(value_chunks)}); "
            f"pad the input to a multiple of 2 * world_size."
        )
    mirror = num_chunks - rank - 1
    local_value = torch.cat([value_chunks[rank], value_chunks[mirror]], dim=dim)
    return local_value.to(device)
def extract_local_from_list(value_list, sp_rank, sp_size):
    """Return the contiguous slice of ``value_list`` assigned to ``sp_rank``.

    The list is divided into ``sp_size`` consecutive parts whose lengths
    differ by at most one; the first ``len(value_list) % sp_size`` ranks
    receive the extra element.
    """
    base, extra = divmod(len(value_list), sp_size)

    def boundary(r):
        # Offset where rank ``r`` begins: ``r`` full parts plus one extra
        # element for every earlier rank that received a remainder item.
        return r * base + min(r, extra)

    return value_list[boundary(sp_rank):boundary(sp_rank + 1)]
def extract_local_from_list_zigzag(value_list, sp_rank, sp_size):
    """Return the zigzag share of ``value_list`` for ``sp_rank``.

    The list is cut into ``2 * sp_size`` consecutive chunks (the leading
    ``len(value_list) % (2 * sp_size)`` chunks hold one extra element).
    The result is chunk ``sp_rank`` followed by its mirror chunk
    ``2 * sp_size - sp_rank - 1``.
    """
    num_chunks = 2 * sp_size
    base, extra = divmod(len(value_list), num_chunks)

    # Cumulative chunk boundaries: bounds[i] .. bounds[i + 1] delimits chunk i.
    bounds = [0]
    for idx in range(num_chunks):
        bounds.append(bounds[-1] + base + (1 if idx < extra else 0))

    pieces = [value_list[lo:hi] for lo, hi in zip(bounds, bounds[1:])]
    return pieces[sp_rank] + pieces[num_chunks - sp_rank - 1]
def extract_local_input_ids(input_ids, image_positions, sp_rank, sp_size, bos_token_id=1, image_token_len=3):
    """Return the slice of ``input_ids`` owned by ``sp_rank``.

    The entries of ``image_positions`` are divided into ``sp_size``
    near-equal consecutive groups, and each rank's slice runs from the
    position of its first entry up to the position of the next rank's first
    entry. The first rank always starts at index 0 and the last rank always
    runs to the end of ``input_ids``.

    Args:
        input_ids: Sliceable sequence to partition.
        image_positions: Indices into ``input_ids`` used as slice boundaries
            (presumably where each image starts; confirm against the caller).
        sp_rank: Rank whose slice is returned.
        sp_size: Total number of ranks.
        bos_token_id: Unused; kept for interface compatibility.
        image_token_len: Unused; kept for interface compatibility.

    Returns:
        The sub-sequence of ``input_ids`` for ``sp_rank``.
    """
    quotient, remainder = divmod(len(image_positions), sp_size)

    # Only look up the boundaries this rank actually needs. The first rank
    # starts at 0 regardless of image_positions, so indexing it eagerly would
    # raise IndexError for inputs with no images (e.g. sp_size == 1).
    if sp_rank == 0:
        start_position_idx = 0
    else:
        start_idx = sp_rank * quotient + min(sp_rank, remainder)
        start_position_idx = image_positions[start_idx]

    if sp_rank == sp_size - 1:
        # Tail of the sequence: everything up to the end.
        end_position_idx = len(input_ids)
    else:
        end_idx = (sp_rank + 1) * quotient + min(sp_rank + 1, remainder)
        end_position_idx = image_positions[end_idx]

    return input_ids[start_position_idx:end_position_idx]
def extract_local_position_ids(input_ids, image_positions, image_ids, sp_rank, sp_size, image_token_len=198):
    """Return the slice of ``input_ids`` owned by ``sp_rank``, with offset boundaries.

    The entries of ``image_ids`` are divided into ``sp_size`` near-equal
    consecutive groups. The boundary for group index ``i`` is
    ``image_positions[i] + image_ids[i] * image_token_len``. The first rank
    always starts at index 0 and the last rank always runs to the end of
    ``input_ids``.

    Args:
        input_ids: Sliceable sequence to partition (presumably position ids
            despite the parameter name; confirm against the caller).
        image_positions: Base indices for the slice boundaries; indexed in
            parallel with ``image_ids``.
        image_ids: Per-image multipliers applied to ``image_token_len``.
        sp_rank: Rank whose slice is returned.
        sp_size: Total number of ranks.
        image_token_len: Offset added per unit of ``image_ids``.

    Returns:
        The sub-sequence of ``input_ids`` for ``sp_rank``.
    """
    quotient, remainder = divmod(len(image_ids), sp_size)

    # Only look up the boundaries this rank actually needs. The first rank
    # starts at 0 regardless of the image metadata, so indexing it eagerly
    # would raise IndexError for inputs with no images (e.g. sp_size == 1).
    if sp_rank == 0:
        start_position_idx = 0
    else:
        start_idx = sp_rank * quotient + min(sp_rank, remainder)
        start_position_idx = image_positions[start_idx] + image_ids[start_idx] * image_token_len

    if sp_rank == sp_size - 1:
        # Tail of the sequence: everything up to the end.
        end_position_idx = len(input_ids)
    else:
        end_idx = (sp_rank + 1) * quotient + min(sp_rank + 1, remainder)
        end_position_idx = image_positions[end_idx] + image_ids[end_idx] * image_token_len

    return input_ids[start_position_idx:end_position_idx]