File size: 16,342 Bytes
28451f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/** @file   common_host.h
 *  @author Thomas Müller and Nikolaus Binder, NVIDIA
 *  @brief  Common utilities that are needed by pretty much every component of this framework.
 */

#pragma once

#include <tiny-cuda-nn/common.h>

#include <fmt/format.h>

#include <array>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

namespace tcnn {

using namespace fmt::literals;

/// Severity tag that accompanies every message passed to the log callback
/// (see `log_callback()` / `set_log_callback()` and the `log_*` helpers).
/// No explicit values are assigned, so the enumerators are numbered in
/// declaration order starting at zero.
enum class LogSeverity {
	Info,
	Debug,
	Warning,
	Error,
	Success,
};

const std::function<void(LogSeverity, const std::string&)>& log_callback();
void set_log_callback(const std::function<void(LogSeverity, const std::string&)>& callback);

/// Formats `msg` with `args` via fmt and passes the resulting string, together
/// with `severity`, to the callback returned by `log_callback()`.
template <typename... Ts>
void log(LogSeverity severity, const std::string& msg, Ts&&... args) {
	const auto& sink = log_callback();
	sink(severity, fmt::format(msg, std::forward<Ts>(args)...));
}

template <typename... Ts> void log_info(const std::string& msg, Ts&&... args) { log(LogSeverity::Info, msg, std::forward<Ts>(args)...); }
template <typename... Ts> void log_debug(const std::string& msg, Ts&&... args) { log(LogSeverity::Debug, msg, std::forward<Ts>(args)...); }
template <typename... Ts> void log_warning(const std::string& msg, Ts&&... args) { log(LogSeverity::Warning, msg, std::forward<Ts>(args)...); }
template <typename... Ts> void log_error(const std::string& msg, Ts&&... args) { log(LogSeverity::Error, msg, std::forward<Ts>(args)...); }
template <typename... Ts> void log_success(const std::string& msg, Ts&&... args) { log(LogSeverity::Success, msg, std::forward<Ts>(args)...); }

bool verbose();
void set_verbose(bool verbose);

/// Throws std::runtime_error if the expression `x` evaluates to false.
/// The message is the `FILE_LINE` prefix (a macro defined outside this header;
/// presumably it expands to a "file:line" string literal -- confirm) followed by
/// " check failed: " and the stringified expression.
#define CHECK_THROW(x) \
	do { if (!(x)) throw std::runtime_error{FILE_LINE " check failed: " #x}; } while(0)

/// Checks the result of a cuXXXXXX (CUDA driver API) call and throws
/// std::runtime_error on failure. The message contains the `FILE_LINE` prefix,
/// the stringified call and the error name reported by `cuGetErrorName`.
/// `cuGetErrorName` sets its output to NULL for unrecognized codes, so the
/// pointer is initialized and replaced by "unknown error" before formatting.
/// Note: the locals `_result` and `msg` are visible to the expansion of `x`.
#define CU_CHECK_THROW(x) \
	do { \
		CUresult _result = (x); \
		if (_result != CUDA_SUCCESS) { \
			const char *msg = nullptr; \
			cuGetErrorName(_result, &msg); \
			throw std::runtime_error{fmt::format(FILE_LINE " " #x " failed: {}", msg ? msg : "unknown error")}; \
		} \
	} while(0)

/// Checks the result of a cuXXXXXX (CUDA driver API) call and, on failure,
/// reports it through `log_error` instead of throwing.
/// `cuGetErrorName` sets its output to NULL for unrecognized codes, so the
/// pointer is initialized and replaced by "unknown error" before formatting;
/// otherwise the formatter could fail while building the message.
/// Note: the locals `_result` and `msg` are visible to the expansion of `x`.
#define CU_CHECK_PRINT(x) \
	do { \
		CUresult _result = (x); \
		if (_result != CUDA_SUCCESS) { \
			const char *msg = nullptr; \
			cuGetErrorName(_result, &msg); \
			log_error(FILE_LINE " " #x " failed: {}", msg ? msg : "unknown error"); \
		} \
	} while(0)

/// Checks the result of a cudaXXXXXX (CUDA runtime API) call and throws
/// std::runtime_error on failure. The message consists of the `FILE_LINE`
/// prefix, the stringified call and the text from `cudaGetErrorString`.
/// `x` is evaluated exactly once; the local `_result` is visible to it.
#define CUDA_CHECK_THROW(x) \
	do { \
		cudaError_t _result = x; \
		if (_result != cudaSuccess) \
			throw std::runtime_error{fmt::format(FILE_LINE " " #x " failed: {}", cudaGetErrorString(_result))}; \
	} while(0)

/// Checks the result of a cudaXXXXXX (CUDA runtime API) call and, on failure,
/// reports it through `log_error` instead of throwing. The message consists of
/// the `FILE_LINE` prefix, the stringified call and the text from
/// `cudaGetErrorString`. `x` is evaluated exactly once.
#define CUDA_CHECK_PRINT(x) \
	do { \
		cudaError_t _result = x; \
		if (_result != cudaSuccess) \
			log_error(FILE_LINE " " #x " failed: {}", cudaGetErrorString(_result)); \
	} while(0)

/// Checks the result of an optixXXXXXX call and throws std::runtime_error on failure.
/// Unlike the CUDA macros in this header, the message contains only the
/// stringified call -- neither the `FILE_LINE` prefix nor the numeric result.
#define OPTIX_CHECK_THROW(x)                                                                                 \
	do {                                                                                                     \
		OptixResult _result = x;                                                                                 \
		if (_result != OPTIX_SUCCESS) {                                                                          \
			throw std::runtime_error(std::string("Optix call '" #x "' failed."));                            \
		}                                                                                                    \
	} while(0)

/// Checks the result of an optixXXXXXX call and throws std::runtime_error with a log message on failure.
/// This macro is not self-contained: it expects two variables named `log` and
/// `sizeof_log` to be in scope at the expansion site (presumably the character
/// buffer and the in/out size passed to the OptiX call -- confirm at call sites).
/// After the call, `sizeof_log` is reset to `sizeof(log)` regardless of the
/// result; on failure, "<truncated>" is appended when the size written back by
/// the call differs from `sizeof(log)`.
#define OPTIX_CHECK_THROW_LOG(x)                                                                                                                          \
	do {                                                                                                                                                  \
		OptixResult _result = x;                                                                                                                              \
		const size_t sizeof_log_returned = sizeof_log;                                                                                                    \
		sizeof_log = sizeof( log ); /* reset sizeof_log for future calls */                                                                               \
		if (_result != OPTIX_SUCCESS) {                                                                                                                       \
			throw std::runtime_error(std::string("Optix call '" #x "' failed. Log:\n") + log + (sizeof_log_returned == sizeof_log ? "" : "<truncated>")); \
		}                                                                                                                                                 \
	} while(0)

//////////////////////////////
// Enum<->string conversion //
//////////////////////////////

Activation string_to_activation(const std::string& activation_name);
std::string to_string(Activation activation);

GridType string_to_grid_type(const std::string& grid_type);
std::string to_string(GridType grid_type);

HashType string_to_hash_type(const std::string& hash_type);
std::string to_string(HashType hash_type);

InterpolationType string_to_interpolation_type(const std::string& interpolation_type);
std::string to_string(InterpolationType interpolation_type);

ReductionType string_to_reduction_type(const std::string& reduction_type);
std::string to_string(ReductionType reduction_type);

//////////////////
// Misc helpers //
//////////////////

int cuda_runtime_version();
/// Returns the value of `cuda_runtime_version()` formatted as "<major>.<minor>".
/// Assumes the usual CUDA encoding `1000 * major + 10 * minor` (e.g. 11020 for
/// 11.2). The minor part is taken from `v % 1000` so that a minor version of
/// 10 or above is not truncated.
inline std::string cuda_runtime_version_string() {
	const int v = cuda_runtime_version();
	const int major = v / 1000;
	const int minor = (v % 1000) / 10;
	return fmt::format("{}.{}", major, minor);
}

int cuda_device();
void set_cuda_device(int device);
int cuda_device_count();

bool cuda_supports_virtual_memory(int device);
inline bool cuda_supports_virtual_memory() {
	return cuda_supports_virtual_memory(cuda_device());
}

std::string cuda_device_name(int device);
inline std::string cuda_device_name() {
	return cuda_device_name(cuda_device());
}

uint32_t cuda_compute_capability(int device);
inline uint32_t cuda_compute_capability() {
	return cuda_compute_capability(cuda_device());
}

uint32_t cuda_max_supported_compute_capability();
uint32_t cuda_supported_compute_capability(int device);
inline uint32_t cuda_supported_compute_capability() {
	return cuda_supported_compute_capability(cuda_device());
}

size_t cuda_max_shmem(int device);
inline size_t cuda_max_shmem() {
	return cuda_max_shmem(cuda_device());
}

uint32_t cuda_max_registers(int device);
inline uint32_t cuda_max_registers() {
	return cuda_max_registers(cuda_device());
}

size_t cuda_memory_granularity(int device);
inline size_t cuda_memory_granularity() {
	return cuda_memory_granularity(cuda_device());
}

/// Plain aggregate returned by `cuda_memory_info()`.
/// NOTE(review): units and the relationship between the fields are not visible
/// in this header; presumably all three are byte counts with
/// used == total - free -- confirm against the implementation.
struct MemoryInfo {
	size_t total; // presumably total device memory
	size_t free;  // presumably currently available device memory
	size_t used;  // presumably memory currently in use
};

MemoryInfo cuda_memory_info();

// Hash helpers taken from https://stackoverflow.com/a/50978188
/// XORs `n` with a copy of itself shifted right by `i` bits.
template <typename T>
T xorshift(T n, int i) {
	const T shifted = n >> i;
	return n ^ shifted;
}

/// 32-bit mixing step: xorshift by 16, multiply by an alternating-bit pattern,
/// xorshift by 16 again, then multiply by an odd constant. All arithmetic
/// wraps modulo 2^32.
inline uint32_t distribute(uint32_t n) {
	const uint32_t alternating_bits = 0x55555555ul; // 0101...01
	const uint32_t odd_multiplier = 3423571495ul;   // arbitrary odd constant
	const uint32_t folded = xorshift(n, 16);
	const uint32_t mixed = xorshift(alternating_bits * folded, 16);
	return odd_multiplier * mixed;
}

/// 64-bit mixing step: xorshift by 32, multiply by an alternating-bit pattern,
/// xorshift by 32 again, then multiply by an odd constant. All arithmetic
/// wraps modulo 2^64.
inline uint64_t distribute(uint64_t n) {
	const uint64_t alternating_bits = 0x5555555555555555ull; // 0101...01
	const uint64_t odd_multiplier = 17316035218449499591ull; // arbitrary odd constant
	const uint64_t folded = xorshift(n, 32);
	const uint64_t mixed = xorshift(alternating_bits * folded, 32);
	return odd_multiplier * mixed;
}

/// Rotates the unsigned value `n` left by `i` bits. The shift amount is masked
/// with `digits - 1`, so it is taken modulo the bit width of `T` for the usual
/// power-of-two widths. Only participates in overload resolution for unsigned `T`.
template <typename T, typename S>
constexpr typename std::enable_if<std::is_unsigned<T>::value, T>::type rotl(const T n, const S i) {
	const T width_mask = (std::numeric_limits<T>::digits - 1);
	const T left_shift = i & width_mask;
	// Complementary right shift; masking keeps it in range when left_shift is 0.
	const T right_shift = ((T)0 - left_shift) & width_mask;
	return (n << left_shift) | (n >> right_shift); // compilers usually recognize this as a rotate
}

/// Folds the hash of `v` into `seed`: the seed is rotated left by a third of
/// the bit width of size_t and XORed with the `distribute`d `std::hash` of `v`.
template <typename T>
size_t hash_combine(std::size_t seed, const T& v) {
	const auto value_hash = std::hash<T>{}(v);
	const auto rotated_seed = rotl(seed, std::numeric_limits<size_t>::digits / 3);
	return rotated_seed ^ distribute(value_hash);
}

std::string generate_device_code_preamble();

std::string to_snake_case(const std::string& str);

std::vector<std::string> split(const std::string& text, const std::string& delim);

/// Streams every element of `components` into a string, separated by `delim`.
/// Returns an empty string for an empty range.
///
/// The first element is tracked with a flag rather than by comparing addresses
/// against `components[0]`, so any range usable in a range-based for loop works,
/// including containers without `operator[]` (std::list, std::set) and
/// proxy-element containers such as std::vector<bool>.
template <typename T>
std::string join(const T& components, const std::string& delim) {
	std::ostringstream s;
	bool first = true;
	for (const auto& component : components) {
		if (!first) {
			s << delim;
		}
		first = false;
		s << component;
	}

	return s.str();
}

/// De-indenting format helper.
///
/// Splits `format` on newlines, drops blank lines (empty or tabs only) at the
/// beginning and end, strips the smallest leading-tab count found on any
/// non-blank line from each line that is long enough, prefixes those lines
/// with `indent` tabs, and finally runs the re-joined text through
/// `fmt::format` with `args`. Returns "" if `format` has no non-blank line.
template <typename... Ts>
std::string dfmt(uint32_t indent, const std::string& format, Ts&&... args) {
	const uint32_t UNSET = std::numeric_limits<uint32_t>::max();

	uint32_t common_indent = UNSET;
	uint32_t blank_head = 0;
	uint32_t blank_tail = 0;
	bool seen_content = false;

	std::vector<std::string> lines = split(format, "\n");
	for (const std::string& line : lines) {
		// Position of the first non-tab character == number of leading tabs.
		const size_t first_non_tab = line.find_first_not_of('\t');

		if (first_non_tab == std::string::npos) {
			// Blank line: counts as leading until content is seen, and as
			// trailing until the next non-blank line resets the counter.
			if (!seen_content) {
				++blank_head;
			}
			++blank_tail;
			continue;
		}

		seen_content = true;
		blank_tail = 0;
		common_indent = std::min(common_indent, (uint32_t)first_non_tab);
	}

	if (common_indent == UNSET) {
		return "";
	}

	// Trim the tail first so that the head offsets stay valid.
	lines.erase(lines.end() - blank_tail, lines.end());
	lines.erase(lines.begin(), lines.begin() + blank_head);

	for (std::string& line : lines) {
		if (line.length() < common_indent) {
			continue; // too short to carry the common indent; left untouched
		}
		line.erase(0, common_indent);
		line.insert(0, indent, '\t');
	}

	return fmt::format(join(lines, "\n"), std::forward<Ts>(args)...);
}

std::string to_lower(std::string str);
std::string to_upper(std::string str);
/// Compares two strings after passing both through `to_lower`.
inline bool equals_case_insensitive(const std::string& str1, const std::string& str2) {
	const std::string lowered_lhs = to_lower(str1);
	const std::string lowered_rhs = to_lower(str2);
	return lowered_lhs == lowered_rhs;
}

/// Hash functor that hashes the `to_lower`ed form of a string.
struct CaseInsensitiveHash {
	size_t operator()(const std::string& v) const {
		const std::string lowered = to_lower(v);
		return std::hash<std::string>{}(lowered);
	}
};
/// Equality functor that forwards to `equals_case_insensitive`.
struct CaseInsensitiveEqual {
	bool operator()(const std::string& l, const std::string& r) const {
		const bool same = equals_case_insensitive(l, r);
		return same;
	}
};

/// std::unordered_map keyed by std::string in which hashing and key equality
/// go through CaseInsensitiveHash and CaseInsensitiveEqual respectively.
template <typename T>
using ci_hashmap = std::unordered_map<std::string, T, CaseInsensitiveHash, CaseInsensitiveEqual>;


template <typename T>
std::string type_to_string();

/// Renders a `tvec` as "tvec<type, N, A>(e0, e1, ...)", using
/// `type_to_string<T>()` for the element type name and `join` for the elements.
template <typename T, uint32_t N, size_t A>
std::string to_string(const tvec<T, N, A>& v) {
	const std::string scalar_name = type_to_string<T>();
	const std::string elements = join(v, ", ");
	return fmt::format("tvec<{}, {}, {}>({})", scalar_name, N, A, elements);
}

/// Formats a byte count as a human-readable string with a binary-scaled suffix
/// (divisions by 1024), e.g. 1536 -> "1.5 KB".
///
/// The value is printed with 3 significant digits. Values that would round to
/// 1000 or more (the range [999.5, 1024) of any unit) get 4 digits, because
/// with only 3 the default stream format switches to scientific notation
/// ("1e+03 B"). The scaling loop stops at the last suffix so that the suffix
/// index can never leave the table.
inline std::string bytes_to_string(size_t bytes) {
	static const std::array<const char*, 7> suffixes = {{ "B", "KB", "MB", "GB", "TB", "PB", "EB" }};

	double count = (double)bytes;
	size_t i = 0;
	for (; i + 1 < suffixes.size() && count >= 1024; ++i) {
		count /= 1024;
	}

	std::ostringstream oss;
	oss.precision(count >= 999.5 ? 4 : 3);
	oss << count << " " << suffixes[i];
	return oss.str();
}

/// Returns true iff `num` is a power of two (zero is not).
/// If `log2` is non-null it receives the number of times `num` could be halved
/// evenly: the exact base-2 logarithm for powers of two, the count of trailing
/// zero bits otherwise, and 0 for `num == 0`.
inline bool is_pot(uint32_t num, uint32_t* log2 = nullptr) {
	uint32_t n_halvings = 0;

	if (num != 0) {
		// Strip factors of two; what remains is 1 only for powers of two.
		while ((num & 1u) == 0) {
			num >>= 1;
			++n_halvings;
		}
	}

	if (log2 != nullptr) {
		*log2 = n_halvings;
	}

	return num == 1;
}

/// Integer power by repeated multiplication: returns `base` raised to
/// `exponent` (1 for exponent 0). The result wraps modulo 2^32 on overflow.
inline uint32_t powi(uint32_t base, uint32_t exponent) {
	uint32_t product = 1;
	while (exponent > 0) {
		product *= base;
		--exponent;
	}
	return product;
}

/// Runs a stored callback when destroyed, unless it holds none.
///
/// Move-only. Both move operations swap callbacks with the source, so after
/// `a = std::move(b)` the callback previously held by `a` lives in `b` and
/// fires when `b` is destroyed. `disarm()` clears the callback so that nothing
/// runs on destruction.
class ScopeGuard {
public:
	ScopeGuard() = default;
	ScopeGuard(const std::function<void()>& callback) : m_callback{callback} {}
	ScopeGuard(std::function<void()>&& callback) : m_callback{std::move(callback)} {}

	// Not copyable: the callback must run at most once per armed guard.
	ScopeGuard(const ScopeGuard& other) = delete;
	ScopeGuard& operator=(const ScopeGuard& other) = delete;

	ScopeGuard(ScopeGuard&& other) {
		m_callback.swap(other.m_callback);
	}

	ScopeGuard& operator=(ScopeGuard&& other) {
		m_callback.swap(other.m_callback);
		return *this;
	}

	~ScopeGuard() {
		if (!m_callback) {
			return;
		}
		m_callback();
	}

	void disarm() {
		m_callback = nullptr;
	}
private:
	std::function<void()> m_callback;
};

/// Lazily generated value.
///
/// `get()` invokes `generator` and stores its result only while the stored
/// value tests false (`!m_val`), then returns a reference to the stored value.
/// A generator that produces a falsy value is therefore invoked again on the
/// next call.
///
/// The member is value-initialized so that the first `!m_val` test is well
/// defined for scalar and pointer types too; with plain default-initialization
/// it would read an indeterminate value.
template <typename T>
class Lazy {
public:
	template <typename F>
	T& get(F&& generator) {
		if (!m_val) {
			m_val = generator();
		}

		return m_val;
	}

private:
	T m_val{};
};

#if defined(__CUDACC__) || (defined(__clang__) && defined(__CUDA__))
/// Launches `kernel` on `stream` over a 1D grid of `n_blocks_linear(n_elements)`
/// blocks with `N_THREADS_LINEAR` threads each and `shmem_size` bytes of
/// dynamic shared memory. The kernel receives `n_elements` followed by `args`.
/// Does nothing when `n_elements <= 0`. The launch result is not checked here.
template <typename K, typename T, typename ... Types>
inline void linear_kernel(K kernel, uint32_t shmem_size, cudaStream_t stream, T n_elements, Types ... args) {
	if (n_elements <= 0) {
		return; // nothing to process
	}

	const auto grid_size = n_blocks_linear(n_elements);
	kernel<<<grid_size, N_THREADS_LINEAR, shmem_size, stream>>>(n_elements, args...);
}

/// 1D kernel: each thread with a global index below `n_elements` calls `fun(i)`.
/// Expects a 1D grid of 1D blocks; threads past the end return immediately.
template <typename F>
__global__ void parallel_for_kernel(const size_t n_elements, F fun) {
	// The built-in index variables are 32-bit. Widen before multiplying so the
	// product cannot wrap when the grid covers more than 2^32 threads.
	const size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
	if (i >= n_elements) return;

	fun(i);
}

/// Launches `parallel_for_kernel` on `stream` so that `fun(i)` runs once for
/// every `i` in `[0, n_elements)`, with `shmem_size` bytes of dynamic shared
/// memory. Does nothing when `n_elements` is zero. The launch result is not
/// checked here.
///
/// The kernel's functor type is deduced from `fun` (by value) instead of being
/// forced to `F`. When the caller passes an lvalue, `F` is an lvalue-reference
/// type, and instantiating the kernel with it would hand a host reference to
/// device code. This matches how the aos/soa launchers below invoke their kernels.
template <typename F>
inline void parallel_for_gpu(uint32_t shmem_size, cudaStream_t stream, size_t n_elements, F&& fun) {
	if (n_elements <= 0) {
		return;
	}
	parallel_for_kernel<<<n_blocks_linear(n_elements), N_THREADS_LINEAR, shmem_size, stream>>>(n_elements, fun);
}

/// Overload of `parallel_for_gpu` that requests no dynamic shared memory.
template <typename F>
inline void parallel_for_gpu(cudaStream_t stream, size_t n_elements, F&& fun) {
	const uint32_t no_shmem = 0;
	parallel_for_gpu(no_shmem, stream, n_elements, std::forward<F>(fun));
}

/// Overload of `parallel_for_gpu` that launches on the null stream.
template <typename F>
inline void parallel_for_gpu(size_t n_elements, F&& fun) {
	const cudaStream_t null_stream = nullptr;
	parallel_for_gpu(null_stream, n_elements, std::forward<F>(fun));
}

/// 2D-block kernel: `threadIdx.x` selects the dimension, and `threadIdx.y`
/// together with `blockIdx.x` selects the element. Calls `fun(elem, dim)` for
/// every in-range pair; out-of-range threads return immediately.
template <typename F>
__global__ void parallel_for_aos_kernel(const size_t n_elements, const uint32_t n_dims, F fun) {
	const size_t dim = threadIdx.x;
	// Widen before multiplying: the built-in index variables are 32-bit and
	// their product could otherwise wrap for very large element counts.
	const size_t elem = (size_t)blockIdx.x * blockDim.y + threadIdx.y;
	if (dim >= n_dims) return;
	if (elem >= n_elements) return;

	fun(elem, dim);
}

/// Launches `parallel_for_aos_kernel` so that `fun(elem, dim)` runs for every
/// element in `[0, n_elements)` and every dimension in `[0, n_dims)`.
/// Each block is `n_dims` threads wide (x) and holds
/// `div_round_up(N_THREADS_LINEAR, n_dims)` elements (y). Does nothing when
/// either extent is zero. The launch result is not checked here.
template <typename F>
inline void parallel_for_gpu_aos(uint32_t shmem_size, cudaStream_t stream, size_t n_elements, uint32_t n_dims, F&& fun) {
	if (n_elements <= 0 || n_dims <= 0) {
		return;
	}

	// x spans the dimensions, y the elements handled by one block.
	const dim3 block_shape = { n_dims, div_round_up(N_THREADS_LINEAR, n_dims), 1 };
	const size_t threads_per_block = block_shape.x * block_shape.y;

	// One thread per (element, dimension) pair, rounded up to whole blocks.
	const dim3 grid_shape = { (uint32_t)div_round_up(n_elements * n_dims, threads_per_block), 1, 1 };

	parallel_for_aos_kernel<<<grid_shape, block_shape, shmem_size, stream>>>(n_elements, n_dims, fun);
}

/// Overload of `parallel_for_gpu_aos` that requests no dynamic shared memory.
template <typename F>
inline void parallel_for_gpu_aos(cudaStream_t stream, size_t n_elements, uint32_t n_dims, F&& fun) {
	const uint32_t no_shmem = 0;
	parallel_for_gpu_aos(no_shmem, stream, n_elements, n_dims, std::forward<F>(fun));
}

/// Overload of `parallel_for_gpu_aos` that launches on the null stream.
template <typename F>
inline void parallel_for_gpu_aos(size_t n_elements, uint32_t n_dims, F&& fun) {
	const cudaStream_t null_stream = nullptr;
	parallel_for_gpu_aos(null_stream, n_elements, n_dims, std::forward<F>(fun));
}

/// 2D-grid kernel: the x axis of the grid/block selects the element and
/// `blockIdx.y` selects the dimension. Calls `fun(elem, dim)` for every
/// in-range pair; out-of-range threads return immediately.
/// Expects to be launched with `gridDim.y` covering `n_dims`.
template <typename F>
__global__ void parallel_for_soa_kernel(const size_t n_elements, const uint32_t n_dims, F fun) {
	// Widen before multiplying: the built-in index variables are 32-bit and
	// their product could otherwise wrap for very large element counts.
	const size_t elem = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
	const size_t dim = blockIdx.y;
	if (elem >= n_elements) return;
	if (dim >= n_dims) return;

	fun(elem, dim);
}

/// Launches `parallel_for_soa_kernel` so that `fun(elem, dim)` runs for every
/// element in `[0, n_elements)` and every dimension in `[0, n_dims)`.
/// The grid is 2D: x covers the elements in blocks of `N_THREADS_LINEAR`
/// threads, y has one row per dimension. Does nothing when either extent is
/// zero. The launch result is not checked here.
///
/// The kernel is launched with the 2D `blocks` grid. Launching with only the
/// 1D block count leaves `blockIdx.y` at 0, so every dimension other than the
/// first would be skipped silently.
/// NOTE(review): `n_dims` becomes `gridDim.y`, which CUDA limits to 65535.
template <typename F>
inline void parallel_for_gpu_soa(uint32_t shmem_size, cudaStream_t stream, size_t n_elements, uint32_t n_dims, F&& fun) {
	if (n_elements <= 0 || n_dims <= 0) {
		return;
	}

	const dim3 blocks = { n_blocks_linear(n_elements), n_dims, 1 };

	parallel_for_soa_kernel<<<blocks, N_THREADS_LINEAR, shmem_size, stream>>>(
		n_elements, n_dims, fun
	);
}

/// Overload of `parallel_for_gpu_soa` that requests no dynamic shared memory.
template <typename F>
inline void parallel_for_gpu_soa(cudaStream_t stream, size_t n_elements, uint32_t n_dims, F&& fun) {
	const uint32_t no_shmem = 0;
	parallel_for_gpu_soa(no_shmem, stream, n_elements, n_dims, std::forward<F>(fun));
}

/// Overload of `parallel_for_gpu_soa` that launches on the null stream.
template <typename F>
inline void parallel_for_gpu_soa(size_t n_elements, uint32_t n_dims, F&& fun) {
	const cudaStream_t null_stream = nullptr;
	parallel_for_gpu_soa(null_stream, n_elements, n_dims, std::forward<F>(fun));
}
#endif

}