# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import os

import matplotlib.pyplot as plt
import numpy as np
import torch


def list_has_common_element(list1, list2):
    """Return True if the two lists share at least one element."""
    return not set(list1).isdisjoint(list2)
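
# Usage sketch:
#
#   >>> list_has_common_element(["fc1", "fc2"], ["fc2", "proj"])
#   True
#   >>> list_has_common_element(["fc1"], ["proj"])
#   False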


def calculate_scale_num(input, row_block, col_block):
    """Return the number of quantization scales for a blockwise layout.

    A block size of -1 means the block spans the whole dimension.
    """
    if len(input.shape) > 2:
        # Collapse all leading dimensions into rows; use the last dim
        # (shape[-1]), since shape[2] would mis-handle tensors with more
        # than three dimensions.
        input = input.reshape(-1, input.shape[-1])
    elif len(input.shape) == 2:
        pass
    else:
        raise ValueError(f"input shape {input.shape} does not match for block cut, {input}")
    M, N = input.shape[0], input.shape[1]
    if row_block == -1:
        row_block = M
    if col_block == -1:
        col_block = N
    return input.numel() / (row_block * col_block)
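
# Usage sketch (hypothetical shapes, not from the original source): a (4, 8)
# tensor quantized in 2x4 blocks needs 32 / (2 * 4) = 4 scale factors; with
# -1 the block spans the whole dimension, i.e. a single per-tensor scale.
#
#   >>> calculate_scale_num(torch.randn(4, 8), 2, 4)
#   4.0
#   >>> calculate_scale_num(torch.randn(4, 8), -1, -1)
#   1.0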


def quant_get_local_rank() -> int:
    """Return this process's local rank, or 0 when LOCAL_RANK is unset or empty."""
    return int(os.environ.get("LOCAL_RANK") or 0)
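
# Usage sketch: distributed launchers such as torchrun export LOCAL_RANK
# for each worker process.
#
#   >>> os.environ["LOCAL_RANK"] = "2"
#   >>> quant_get_local_rank()
#   2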


def format_string_with_condition(
    input_string,
    condition_config,
    symm,
    bits,
    blocksize_config,
    input_pad=20,
):
    """Build one fixed-width report line from a name and its quantization settings."""
    output_string = input_string.ljust(input_pad)
    for k, v in condition_config.items():
        # True is printed in the left slot and False in the right slot, so the
        # two states occupy distinct, aligned columns.
        if v:
            output_string = output_string + k.ljust(10) + "True".ljust(6) + "".ljust(6)
        else:
            output_string = output_string + k.ljust(10) + "".ljust(6) + "False".ljust(6)
    output_string = output_string + f"Symm {symm}".ljust(10)
    for k, v in bits.items():
        output_string = output_string + f"{k} bit".ljust(10) + v.ljust(10)
    for k, v in blocksize_config.items():
        output_string += f"{k}: {v}".ljust(15)
    return output_string
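
# Usage sketch (hypothetical settings for illustration): each call yields one
# aligned report line, roughly
#
#   fc1.weight          Quantize  True        Symm True W bit     4         B: 128
#
#   >>> line = format_string_with_condition(
#   ...     "fc1.weight", {"Quantize": True}, True, {"W": "4"}, {"B": 128}
#   ... )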


def print_warning(sentence):
    """Print ``sentence`` inside a box of asterisks so it stands out in logs."""
    print("*" * (len(sentence) + 4))
    print(f"* {sentence} *")
    print("*" * (len(sentence) + 4))


def check_nan_inf(tensor, check_nan, check_inf):
    """Return (contains_nan, contains_inf), skipping whichever check is disabled.

    Results are converted to Python bools so both slots have a consistent type
    (``torch.Tensor.any()`` would otherwise return a 0-dim tensor).
    """
    contain_nan = bool(torch.isnan(tensor).any()) if check_nan else False
    contain_inf = bool(torch.isinf(tensor).any()) if check_inf else False
    return contain_nan, contain_inf
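
# Usage sketch:
#
#   >>> check_nan_inf(torch.tensor([1.0, float("nan")]), check_nan=True, check_inf=True)
#   (True, False)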


def move_torch_to_numpy(tensor):
    """Detach ``tensor`` and return it as a float32 NumPy array (None passes through)."""
    if tensor is None:
        return None
    if tensor.is_cuda:
        tensor = tensor.cpu()
    # .float() upcasts to float32 so dtypes NumPy lacks (e.g. bfloat16) convert cleanly.
    return tensor.detach().float().numpy()


def flatten_to_1d(tensor):
    """Flatten ``tensor`` (torch or NumPy) to one dimension (None passes through)."""
    if tensor is None:
        return None
    return tensor.reshape(-1)
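
# Usage sketch: these two helpers pair naturally with the matplotlib import
# above for histogram plotting. The pipeline below is a hypothetical
# illustration, not code from the original repository.
#
#   >>> w = torch.randn(64, 64)
#   >>> values = flatten_to_1d(move_torch_to_numpy(w))
#   >>> plt.hist(values, bins=100)  # doctest: +SKIP
#   >>> plt.savefig("weight_hist.png")  # doctest: +SKIP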