Spaces:
Build error
Build error
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/** @file cuda_graph.h | |
* @author Thomas Müller, NVIDIA | |
* @brief Implementation of a CUDA graph capture/update with subsequent execution | |
*/ | |
namespace tcnn { | |
class CudaGraph; | |
// Gives access to the calling thread's list of CudaGraph objects that are currently
// capturing. The list is a function-local `thread_local` static, so every thread
// observes its own, independent deque. CudaGraph::capture_guard() appends to the back
// when a capture begins and removes from the back when it ends.
inline std::deque<CudaGraph*>& current_captures() {
	static thread_local std::deque<CudaGraph*> captures_of_this_thread{};
	return captures_of_this_thread;
}
// Returns the CudaGraph at the front of the calling thread's capture list, or nullptr
// when this thread has no capture registered. Because captures are appended at the
// back, the front entry is the one that was registered first among those still active.
inline CudaGraph* current_capture() {
	auto& active = current_captures();
	if (active.empty()) {
		return nullptr;
	}

	return active.front();
}
class CudaGraph { | |
public: | |
~CudaGraph() { | |
try { | |
reset(); | |
} catch (const std::runtime_error& error) { | |
// Don't need to report on destruction problems when the driver is shutting down. | |
if (std::string{error.what()}.find("driver shutting down") == std::string::npos) { | |
log_warning("Could not destroy cuda graph: {}", error.what()); | |
} | |
} | |
} | |
ScopeGuard capture_guard(cudaStream_t stream) { | |
// Can't capture on the global stream | |
if (stream == nullptr || stream == cudaStreamLegacy) { | |
return {}; | |
} | |
// If the caller is already capturing, no need for a nested capture. | |
cudaStreamCaptureStatus capture_status; | |
CUDA_CHECK_THROW(cudaStreamIsCapturing(stream, &capture_status)); | |
if (capture_status != cudaStreamCaptureStatusNone) { | |
return {}; | |
} | |
cudaError_t capture_result = cudaStreamIsCapturing(cudaStreamLegacy, &capture_status); | |
if (capture_result == cudaErrorStreamCaptureImplicit) { | |
return {}; | |
} | |
CUDA_CHECK_THROW(capture_result); | |
if (capture_status != cudaStreamCaptureStatusNone) { | |
return {}; | |
} | |
// Start capturing | |
if (m_graph) { | |
CUDA_CHECK_THROW(cudaGraphDestroy(m_graph)); | |
m_graph = nullptr; | |
} | |
CUDA_CHECK_THROW(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed)); | |
current_captures().push_back(this); | |
// Stop capturing again once the returned object goes out of scope | |
return ScopeGuard{[this, stream]() { | |
CUDA_CHECK_THROW(cudaStreamEndCapture(stream, &m_graph)); | |
if (current_captures().back() != this) { | |
throw std::runtime_error{"CudaGraph: must end captures in reverse order of creation."}; | |
} | |
current_captures().pop_back(); | |
if (m_synchronize_when_capture_done) { | |
CUDA_CHECK_THROW(cudaDeviceSynchronize()); | |
m_synchronize_when_capture_done = false; | |
} | |
// Capture failed for some reason. Reset state and don't execute anything. | |
// A corresponding exception is likely already in flight. | |
if (!m_graph) { | |
if (m_graph_instance) { | |
CUDA_CHECK_THROW(cudaGraphExecDestroy(m_graph_instance)); | |
} | |
m_graph = nullptr; | |
m_graph_instance = nullptr; | |
return; | |
} | |
// If we previously created a graph instance, try to update it with the newly captured graph. | |
// This is cheaper than creating a new instance from scratch (and may involve just updating | |
// pointers rather than changing the topology of the graph.) | |
if (m_graph_instance) { | |
cudaGraphExecUpdateResultInfo update_result; | |
CUDA_CHECK_THROW(cudaGraphExecUpdate(m_graph_instance, m_graph, &update_result)); | |
// If the update failed, reset graph instance. We will create a new one next. | |
if (update_result.result != cudaGraphExecUpdateSuccess) { | |
CUDA_CHECK_THROW(cudaGraphExecDestroy(m_graph_instance)); | |
m_graph_instance = nullptr; | |
} | |
cudaGraphExecUpdateResult update_result; | |
cudaGraphNode_t error_node; | |
CUDA_CHECK_THROW(cudaGraphExecUpdate(m_graph_instance, m_graph, &error_node, &update_result)); | |
// If the update failed, reset graph instance. We will create a new one next. | |
if (update_result != cudaGraphExecUpdateSuccess) { | |
CUDA_CHECK_THROW(cudaGraphExecDestroy(m_graph_instance)); | |
m_graph_instance = nullptr; | |
} | |
} | |
if (!m_graph_instance) { | |
CUDA_CHECK_THROW(cudaGraphInstantiate(&m_graph_instance, m_graph, NULL, NULL, 0)); | |
} | |
CUDA_CHECK_THROW(cudaGraphLaunch(m_graph_instance, stream)); | |
}}; | |
} | |
void reset() { | |
if (m_graph) { | |
CUDA_CHECK_THROW(cudaGraphDestroy(m_graph)); | |
m_graph = nullptr; | |
} | |
if (m_graph_instance) { | |
CUDA_CHECK_THROW(cudaGraphExecDestroy(m_graph_instance)); | |
m_graph_instance = nullptr; | |
} | |
} | |
void schedule_synchronize() { | |
m_synchronize_when_capture_done = true; | |
} | |
private: | |
cudaGraph_t m_graph = nullptr; | |
cudaGraphExec_t m_graph_instance = nullptr; | |
bool m_synchronize_when_capture_done = false; | |
}; | |
} | |