/* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /** @file multi_stream.h * @author Thomas Müller, NVIDIA * @brief Helper class for parallelizing workload across multiple streams. */ #pragma once #include #include namespace tcnn { void free_multi_streams(cudaStream_t parent_stream); // Synchronization helpers struct StreamAndEvent { public: StreamAndEvent() { CUDA_CHECK_THROW(cudaStreamCreate(&m_stream)); CUDA_CHECK_THROW(cudaEventCreate(&m_event)); } ~StreamAndEvent() { if (m_stream) { free_multi_streams(m_stream); free_gpu_memory_arena(m_stream); cudaStreamDestroy(m_stream); } if (m_event) { cudaEventDestroy(m_event); } } // Only allow moving of these guys. No copying. StreamAndEvent& operator=(const StreamAndEvent&) = delete; StreamAndEvent(const StreamAndEvent&) = delete; StreamAndEvent& operator=(StreamAndEvent&& other) { std::swap(m_stream, other.m_stream); std::swap(m_event, other.m_event); return *this; } StreamAndEvent(StreamAndEvent&& other) { *this = std::move(other); } void wait_for(cudaEvent_t event) { CUDA_CHECK_THROW(cudaStreamWaitEvent(m_stream, event, 0)); } void wait_for(cudaStream_t stream) { CUDA_CHECK_THROW(cudaEventRecord(m_event, stream)); wait_for(m_event); } void signal(cudaStream_t stream) { CUDA_CHECK_THROW(cudaEventRecord(m_event, m_stream)); CUDA_CHECK_THROW(cudaStreamWaitEvent(stream, m_event, 0)); } cudaStream_t get() { return m_stream; } private: cudaStream_t m_stream = {}; cudaEvent_t m_event = {}; }; struct MultiStream { public: MultiStream() { CUDA_CHECK_THROW(cudaEventCreate(&m_event)); } ~MultiStream() { cudaEventDestroy(m_event); } MultiStream& operator=(const MultiStream&) = delete; MultiStream(const MultiStream&) = delete; MultiStream& operator=(MultiStream&&) = delete; MultiStream(MultiStream&&) = delete; void signal(cudaStream_t outer_stream) { for (size_t i = 0; i < m_n_streams; ++i) { m_streams[i].signal(outer_stream); } } void wait_for(cudaStream_t stream) { if (m_n_streams == 0) { return; } CUDA_CHECK_THROW(cudaEventRecord(m_event, stream)); for (size_t i = 0; i < m_n_streams; ++i) { m_streams[i].wait_for(m_event); } } void resize(size_t n_streams) { if (n_streams > m_streams.size()) { m_streams.resize(n_streams); } m_n_streams = n_streams; } cudaStream_t get(size_t idx) { if (idx >= m_n_streams) { throw std::runtime_error{fmt::format("MultiStream: invalid stream index requested: {}/{}", idx, m_n_streams)}; } return m_streams.at(idx).get(); } private: std::vector m_streams; // May be less than m_streams.size()! // The user may only need to sync fewer than that. size_t m_n_streams = 0; cudaEvent_t m_event; }; inline std::unordered_map>>& stream_multi_streams() { static auto* stream_multi_streams = new std::unordered_map>>{}; return *stream_multi_streams; } inline std::unordered_map>>& global_multi_streams() { static auto* global_multi_streams = new std::unordered_map>>{}; return *global_multi_streams; } inline std::stack>& get_multi_stream_stack(cudaStream_t parent_stream) { return parent_stream ? stream_multi_streams()[parent_stream] : global_multi_streams()[cuda_device()]; } inline void free_multi_streams(cudaStream_t parent_stream) { CHECK_THROW(parent_stream); // Copy the multi stream shared_ptr's into a separate variable, // such that their destruction happens after unordered_map::erase(...) // is already finished. This alleviates potential non-reentrancy problems. auto multi_streams = stream_multi_streams()[parent_stream]; stream_multi_streams().erase(parent_stream); } inline std::shared_ptr reserve_multi_stream(cudaStream_t parent_stream, size_t n_streams) { auto& stack = get_multi_stream_stack(parent_stream); if (stack.empty()) { stack.push(std::make_shared()); } auto result = stack.top(); stack.pop(); result->resize(n_streams); return result; } inline void return_multi_stream(cudaStream_t parent_stream, std::shared_ptr multi_stream) { if (parent_stream ? (stream_multi_streams().count(parent_stream) == 0) : (global_multi_streams().count(cuda_device()) == 0)) { throw std::runtime_error{"Attempted to return multi stream to the wrong parent stream."}; } auto& stack = get_multi_stream_stack(parent_stream); stack.push(multi_stream); } // RAII wrapper around MultiStream struct SyncedMultiStream { public: SyncedMultiStream() = default; SyncedMultiStream(cudaStream_t stream, size_t n_streams) : m_main_stream{stream}, m_n_streams{n_streams} { if (m_n_streams == 0) { throw std::runtime_error{"SyncedMultiStream: must request at least one stream"}; } else if (m_n_streams == 1) { return; } m_multi_stream = reserve_multi_stream(m_main_stream, m_n_streams-1); m_multi_stream->wait_for(m_main_stream); } ~SyncedMultiStream() { if (m_multi_stream) { m_multi_stream->signal(m_main_stream); return_multi_stream(m_main_stream, m_multi_stream); } } // Only allow moving of these guys. No copying. SyncedMultiStream& operator=(const SyncedMultiStream& other) = delete; SyncedMultiStream(const SyncedMultiStream&) = delete; SyncedMultiStream& operator=(SyncedMultiStream&& other) { std::swap(m_multi_stream, other.m_multi_stream); std::swap(m_main_stream, other.m_main_stream); std::swap(m_n_streams, other.m_n_streams); return *this; } SyncedMultiStream(SyncedMultiStream&& other) { *this = std::move(other); } cudaStream_t get(size_t idx) { if (m_n_streams == 0) { throw std::runtime_error{"SyncedMultiStream: must have at least one stream"}; } if (idx == 0) { return m_main_stream; } else { if (!m_multi_stream) { throw std::runtime_error{"SyncedMultiStream: invalid multistream"}; } return m_multi_stream->get(idx-1); } } private: std::shared_ptr m_multi_stream = nullptr; cudaStream_t m_main_stream = nullptr; size_t m_n_streams = 0; }; }