Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import argparse | |
| import os | |
| import sys | |
| import termcolor | |
# Command-line interface. Only ``--fix`` is a declared option; every argument
# argparse does not recognize is kept and later treated as a path to process.
parser = argparse.ArgumentParser(
    description="Cosmos IP header checker/fixer",
)
parser.add_argument(
    "--fix",
    help="apply the fixes instead of checking",
    action="store_true",
)
# parse_known_args returns (namespace, leftover_argv); passing None reads sys.argv[1:].
args, files_to_check = parser.parse_known_args(None)
| def get_header(ext: str = "py", old: str | bool = False) -> list[str]: | |
| header = [ | |
| "SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.", | |
| "SPDX-License-Identifier: Apache-2.0", | |
| "", | |
| 'Licensed under the Apache License, Version 2.0 (the "License");', | |
| "you may not use this file except in compliance with the License.", | |
| "You may obtain a copy of the License at", | |
| "", | |
| "http://www.apache.org/licenses/LICENSE-2.0", | |
| "", | |
| "Unless required by applicable law or agreed to in writing, software", | |
| 'distributed under the License is distributed on an "AS IS" BASIS,', | |
| "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", | |
| "See the License for the specific language governing permissions and", | |
| "limitations under the License.", | |
| ] | |
| if ext == ".py" and old: | |
| if old == "single": | |
| header = ["'''"] + header + ["'''"] | |
| elif old == "double": | |
| header = ['"""'] + header + ['"""'] | |
| else: | |
| raise NotImplementedError | |
| elif ext in (".py", ".yaml"): | |
| header = [("# " + line if line else "#") for line in header] | |
| elif ext in (".c", ".cpp", ".cu", ".h", ".cuh"): | |
| header = ["/*"] + [(" * " + line if line else " *") for line in header] + [" */"] | |
| else: | |
| raise NotImplementedError | |
| return header | |
def apply_file(file: str, results: dict[str, int], fix: bool = False) -> None:
    """Check, or with ``fix`` rewrite, the license header of a single file.

    Files whose path ends with ``__init__.py`` are skipped entirely.

    Args:
        file: Path of the file to process. Its extension selects the comment
            style of the expected header.
        results: Updated in check mode only: ``results[file]`` is set to 1
            when the header check fails and to 0 when it passes.
        fix: When true, rewrite a failing file in place with the expected
            header followed by one blank line; ``results`` is not touched.
    """
    if file.endswith("__init__.py"):
        return

    ext = os.path.splitext(file)[1]
    # Context manager so the handle is closed promptly instead of being leaked.
    with open(file) as file_obj:
        content = file_obj.read().splitlines()
    header = get_header(ext=ext)
    header_ok = _check_header(content, header)

    if not fix:
        if not header_ok:
            bad_header = colorize("BAD HEADER", color="red", bold=True)
            print(f"{bad_header}: {file}")
        results[file] = 0 if header_ok else 1
        return

    if header_ok:
        return

    print(f"fixing: {file}")
    # The check can fail even though the header text is present (too many
    # blank lines after it, or nothing following it). Drop the existing copy
    # first so the rewrite never stacks a second header on top.
    if content[: len(header)] == header:
        content = content[len(header) :]
    # Remove leading blank lines; whitespace-only lines count as blank, which
    # matches how _check_header counts them.
    first_non_blank = next(
        (index for index, line in enumerate(content) if line.strip()),
        len(content),
    )
    content = content[first_non_blank:]

    with open(file, "w") as file_obj:
        file_obj.writelines(line + "\n" for line in header + [""] + content)
| def traverse_directory(path: str, results: dict[str, int], fix: bool = False, substrings_to_skip=[]) -> None: | |
| files = os.listdir(path) | |
| for file in files: | |
| full_path = os.path.join(path, file) | |
| if os.path.isdir(full_path): | |
| traverse_directory(full_path, results, fix=fix, substrings_to_skip=substrings_to_skip) | |
| elif os.path.isfile(full_path): | |
| ext = os.path.splitext(file)[1] | |
| to_skip = any(substr in full_path for substr in substrings_to_skip) | |
| if not to_skip and ext in (".py", ".yaml", ".c", ".cpp", ".cu", ".h", ".cuh"): | |
| apply_file(full_path, results, fix=fix) | |
| else: | |
| raise NotImplementedError | |
def _check_header(content: list[str], header: list[str]) -> bool:
    """Return whether ``content`` starts with ``header`` in the accepted layout.

    The layout is accepted when the first ``len(header)`` lines equal
    ``header``, they are followed by at most two blank (empty or
    whitespace-only) lines, and at least one non-blank line comes after that.
    """
    header_len = len(header)
    if content[:header_len] != header:
        return False

    remainder = content[header_len:]
    # Count the run of blank lines directly after the header.
    leading_blanks = 0
    for line in remainder:
        if line.strip():
            break
        leading_blanks += 1

    # No more than two blank lines, and something non-blank must follow them.
    return leading_blanks <= 2 and leading_blanks < len(remainder)
def colorize(x: str, color: str, bold: bool = False) -> str:
    """Return ``str(x)`` styled by ``termcolor.colored`` in the given color.

    The ``"bold"`` attribute is requested only when ``bold`` is true;
    otherwise no attributes are passed.
    """
    attributes = ("bold",) if bold else None
    text = str(x)
    return termcolor.colored(text, color=color, attrs=attributes)  # type: ignore
if __name__ == "__main__":
    # Fall back to the default set of directories when no paths were given
    # on the command line.
    if not files_to_check:
        files_to_check = [
            "cosmos_predict1/auxiliary",
            "cosmos_predict1/diffusion",
            "cosmos_predict1/callbacks",
            "cosmos_predict1/checkpointer",
            "cosmos_predict1/autoregressive",
            "cosmos_predict1/tokenizer",
            "cosmos_predict1/utils",
        ]

    # Validate every path before touching anything, so a bad argument cannot
    # leave a --fix run half applied. This is an explicit raise rather than
    # an assert because asserts are stripped when Python runs with -O.
    for file in files_to_check:
        if not (os.path.isfile(file) or os.path.isdir(file)):
            raise FileNotFoundError(f"{file} is neither a directory or a file!")

    substrings_to_skip = ["prompt_upsampler"]
    results: dict[str, int] = {}
    for file in files_to_check:
        if os.path.isfile(file):
            apply_file(file, results, fix=args.fix)
        elif os.path.isdir(file):
            traverse_directory(file, results, fix=args.fix, substrings_to_skip=substrings_to_skip)
        else:
            # Only reachable if the path changed type or vanished after validation.
            raise NotImplementedError

    # Non-zero exit status when at least one checked file has a bad header.
    if any(results.values()):
        sys.exit(1)