/* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /** @file common_host.cu * @author Thomas Müller and Nikolaus Binder, NVIDIA * @brief Common utilities that are needed by pretty much every component of this framework. */ #include #include #include #include #include #include #include #include #include namespace tcnn { static_assert( __CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2), "tiny-cuda-nn requires at least CUDA 10.2" ); std::function g_log_callback = [](LogSeverity severity, const std::string& msg) { switch (severity) { case LogSeverity::Warning: std::cerr << fmt::format("tiny-cuda-nn warning: {}\n", msg); break; case LogSeverity::Error: std::cerr << fmt::format("tiny-cuda-nn error: {}\n", msg); break; default: break; } if (verbose()) { switch (severity) { case LogSeverity::Debug: std::cerr << fmt::format("tiny-cuda-nn debug: {}\n", msg); break; case LogSeverity::Info: std::cerr << fmt::format("tiny-cuda-nn info: {}\n", msg); break; case LogSeverity::Success: std::cerr << fmt::format("tiny-cuda-nn success: {}\n", msg); break; default: break; } } }; const std::function& log_callback() { return g_log_callback; } void set_log_callback(const std::function& cb) { g_log_callback = cb; } bool g_verbose = false; bool verbose() { return g_verbose; } void set_verbose(bool verbose) { g_verbose = verbose; } Activation string_to_activation(const std::string& activation_name) { if (equals_case_insensitive(activation_name, "None")) { return Activation::None; } else if (equals_case_insensitive(activation_name, "ReLU")) { return Activation::ReLU; } else if (equals_case_insensitive(activation_name, "LeakyReLU")) { return Activation::LeakyReLU; } else if (equals_case_insensitive(activation_name, "Exponential")) { return Activation::Exponential; } else if (equals_case_insensitive(activation_name, "Sigmoid")) { return Activation::Sigmoid; } else if (equals_case_insensitive(activation_name, "Sine")) { return Activation::Sine; } else if (equals_case_insensitive(activation_name, "Squareplus")) { return Activation::Squareplus; } else if (equals_case_insensitive(activation_name, "Softplus")) { return Activation::Softplus; } else if (equals_case_insensitive(activation_name, "Tanh")) { return Activation::Tanh; } throw std::runtime_error{fmt::format("Invalid activation name: {}", activation_name)}; } std::string to_string(Activation activation) { switch (activation) { case Activation::None: return "None"; case Activation::ReLU: return "ReLU"; case Activation::LeakyReLU: return "LeakyReLU"; case Activation::Exponential: return "Exponential"; case Activation::Sigmoid: return "Sigmoid"; case Activation::Sine: return "Sine"; case Activation::Squareplus: return "Squareplus"; case Activation::Softplus: return "Softplus"; case Activation::Tanh: return "Tanh"; default: throw std::runtime_error{"Invalid activation."}; } } GridType string_to_grid_type(const std::string& grid_type) { if (equals_case_insensitive(grid_type, "Hash")) { return GridType::Hash; } else if (equals_case_insensitive(grid_type, "Dense")) { return GridType::Dense; } else if (equals_case_insensitive(grid_type, "Tiled") || equals_case_insensitive(grid_type, "Tile")) { return GridType::Tiled; } throw std::runtime_error{fmt::format("Invalid grid type: {}", grid_type)}; } std::string to_string(GridType grid_type) { switch (grid_type) { case GridType::Hash: return "Hash"; case GridType::Dense: return "Dense"; case GridType::Tiled: return "Tiled"; default: throw std::runtime_error{"Invalid grid type."}; } } HashType string_to_hash_type(const std::string& hash_type) { if (equals_case_insensitive(hash_type, "Prime")) { return HashType::Prime; } else if (equals_case_insensitive(hash_type, "CoherentPrime")) { return HashType::CoherentPrime; } else if (equals_case_insensitive(hash_type, "ReversedPrime")) { return HashType::ReversedPrime; } else if (equals_case_insensitive(hash_type, "Rng")) { return HashType::Rng; } else if (equals_case_insensitive(hash_type, "BaseConvert")) { return HashType::BaseConvert; } throw std::runtime_error{fmt::format("Invalid hash type: {}", hash_type)}; } std::string to_string(HashType hash_type) { switch (hash_type) { case HashType::Prime: return "Prime"; case HashType::CoherentPrime: return "CoherentPrime"; case HashType::ReversedPrime: return "ReversedPrime"; case HashType::Rng: return "Rng"; case HashType::BaseConvert: return "BaseConvert"; default: throw std::runtime_error{"Invalid hash type."}; } } InterpolationType string_to_interpolation_type(const std::string& interpolation_type) { if (equals_case_insensitive(interpolation_type, "Nearest")) { return InterpolationType::Nearest; } else if (equals_case_insensitive(interpolation_type, "Linear")) { return InterpolationType::Linear; } else if (equals_case_insensitive(interpolation_type, "Smoothstep")) { return InterpolationType::Smoothstep; } throw std::runtime_error{fmt::format("Invalid interpolation type: {}", interpolation_type)}; } std::string to_string(InterpolationType interpolation_type) { switch (interpolation_type) { case InterpolationType::Nearest: return "Nearest"; case InterpolationType::Linear: return "Linear"; case InterpolationType::Smoothstep: return "Smoothstep"; default: throw std::runtime_error{"Invalid interpolation type."}; } } ReductionType string_to_reduction_type(const std::string& reduction_type) { if (equals_case_insensitive(reduction_type, "Concatenation")) { return ReductionType::Concatenation; } else if (equals_case_insensitive(reduction_type, "Sum")) { return ReductionType::Sum; } else if (equals_case_insensitive(reduction_type, "Product")) { return ReductionType::Product; } throw std::runtime_error{fmt::format("Invalid reduction type: {}", reduction_type)}; } std::string to_string(ReductionType reduction_type) { switch (reduction_type) { case ReductionType::Concatenation: return "Concatenation"; case ReductionType::Sum: return "Sum"; case ReductionType::Product: return "Product"; default: throw std::runtime_error{"Invalid reduction type."}; } } int cuda_runtime_version() { int version; CUDA_CHECK_THROW(cudaRuntimeGetVersion(&version)); return version; } int cuda_device() { int device; CUDA_CHECK_THROW(cudaGetDevice(&device)); return device; } void set_cuda_device(int device) { CUDA_CHECK_THROW(cudaSetDevice(device)); } int cuda_device_count() { int device_count; CUDA_CHECK_THROW(cudaGetDeviceCount(&device_count)); return device_count; } bool cuda_supports_virtual_memory(int device) { int supports_vmm; CU_CHECK_THROW(cuDeviceGetAttribute(&supports_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, device)); return supports_vmm != 0; } std::unordered_map& cuda_device_properties() { static auto* cuda_device_props = new std::unordered_map{}; return *cuda_device_props; } const cudaDeviceProp& cuda_get_device_properties(int device) { if (cuda_device_properties().count(device) == 0) { auto& props = cuda_device_properties()[device]; CUDA_CHECK_THROW(cudaGetDeviceProperties(&props, device)); } return cuda_device_properties().at(device); } std::string cuda_device_name(int device) { return cuda_get_device_properties(device).name; } uint32_t cuda_compute_capability(int device) { const auto& props = cuda_get_device_properties(device); return props.major * 10 + props.minor; } uint32_t cuda_max_supported_compute_capability() { int cuda_version = cuda_runtime_version(); if (cuda_version < 11000) { return 75; } else if (cuda_version < 11010) { return 80; } else if (cuda_version < 11080) { return 86; } else { return 90; } } uint32_t cuda_supported_compute_capability(int device) { return std::min(cuda_compute_capability(device), cuda_max_supported_compute_capability()); } size_t cuda_max_shmem(int device) { return cuda_get_device_properties(device).sharedMemPerBlockOptin; } uint32_t cuda_max_registers(int device) { return (uint32_t)cuda_get_device_properties(device).regsPerBlock; } size_t cuda_memory_granularity(int device) { size_t granularity; CUmemAllocationProp prop = {}; prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; prop.location.id = 0; CUresult granularity_result = cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM); if (granularity_result == CUDA_ERROR_NOT_SUPPORTED) { return 1; } CU_CHECK_THROW(granularity_result); return granularity; } MemoryInfo cuda_memory_info() { MemoryInfo info; CUDA_CHECK_THROW(cudaMemGetInfo(&info.free, &info.total)); info.used = info.total - info.free; return info; } std::string generate_device_code_preamble() { return dfmt(0, R"( #include #include using namespace tcnn; )"); } std::string to_snake_case(const std::string& str) { std::stringstream result; result << (char)std::tolower(str[0]); for (uint32_t i = 1; i < str.length(); ++i) { if (std::isupper(str[i])) { result << "_" << (char)std::tolower(str[i]); } else { result << str[i]; } } return result.str(); } std::vector split(const std::string& text, const std::string& delim) { std::vector result; size_t begin = 0; while (true) { size_t end = text.find_first_of(delim, begin); if (end == std::string::npos) { result.emplace_back(text.substr(begin)); return result; } else { result.emplace_back(text.substr(begin, end - begin)); begin = end + 1; } } return result; } std::string to_lower(std::string str) { std::transform(std::begin(str), std::end(str), std::begin(str), [](unsigned char c) { return (char)std::tolower(c); }); return str; } std::string to_upper(std::string str) { std::transform(std::begin(str), std::end(str), std::begin(str), [](unsigned char c) { return (char)std::toupper(c); }); return str; } template <> std::string type_to_string() { return "bool"; } template <> std::string type_to_string() { return "int"; } template <> std::string type_to_string() { return "char"; } template <> std::string type_to_string() { return "uint8_t"; } template <> std::string type_to_string() { return "uint16_t"; } template <> std::string type_to_string() { return "uint32_t"; } template <> std::string type_to_string() { return "double"; } template <> std::string type_to_string() { return "float"; } template <> std::string type_to_string<__half>() { return "__half"; } }