gen3c / gui /api /multi_gpu.py
elungky's picture
Initial commit for new Space - pre-built Docker image
28451f7
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import argparse
import os
import signal
from loguru import logger as log
from v2v_utils import move_to_device, clone_tensors
TORCHRUN_DEFAULT_MASTER_ADDR = 'localhost'
TORCHRUN_DEFAULT_MASTER_PORT = 12355
def _get_inference_class(cosmos_variant: str):
if cosmos_variant == 'predict1':
from cosmos_predict1.diffusion.inference.gen3c_persistent import Gen3cPersistentModel
from cosmos_predict1.utils.distributed import is_rank0
return Gen3cPersistentModel, is_rank0
else:
raise ValueError(f"Unsupported cosmos variant: {cosmos_variant}")
def _inference_worker(rank: int, args: argparse.Namespace,
gpu_count: int,
cosmos_variant: str,
input_queues: 'list[torch.multiprocessing.Queue]',
result_queue: 'torch.multiprocessing.Queue',
attrs_queue: 'torch.multiprocessing.Queue'):
"""
One such function will run, in a separate process, for each GPU.
Each process loads the model and keeps it in memory.
"""
log.debug(f'inference_worker for rank {rank} starting, doing imports now')
import torch
import torch.distributed as dist
InferenceAR, is_tp_cp_pp_rank0 = _get_inference_class(cosmos_variant)
log.debug(f'inference_worker for rank {rank} done with imports.')
# The FQDN of the host that is running worker with rank 0; used to initialize the Torch Distributed backend.
os.environ.setdefault("MASTER_ADDR", TORCHRUN_DEFAULT_MASTER_ADDR)
# The port on the MASTER_ADDR that can be used to host the C10d TCP store.
os.environ.setdefault("MASTER_PORT", str(TORCHRUN_DEFAULT_MASTER_PORT))
# The local rank.
os.environ["LOCAL_RANK"] = str(rank)
# The global rank.
os.environ["RANK"] = str(rank)
# The rank of the worker group. A number between 0 and max_nnodes. When running a single worker group per node, this is the rank of the node.
os.environ["GROUP_RANK"] = str(rank)
# The rank of the worker across all the workers that have the same role. The role of the worker is specified in the WorkerSpec.
os.environ["ROLE_RANK"] = str(rank)
# The local world size (e.g. number of workers running locally); equals to --nproc-per-node specified on torchrun.
os.environ["LOCAL_WORLD_SIZE"] = str(gpu_count)
# The world size (total number of workers in the job).
os.environ["WORLD_SIZE"] = str(gpu_count)
# The total number of workers that was launched with the same role specified in WorkerSpec.
os.environ["ROLE_WORLD_SIZE"] = str(gpu_count)
# # The number of worker group restarts so far.
# os.environ["TORCHELASTIC_RESTART_COUNT"] = TODO
# # The configured maximum number of restarts.
# os.environ["TORCHELASTIC_MAX_RESTARTS"] = TODO
# # Equal to the rendezvous run_id (e.g. unique job id).
# os.environ["TORCHELASTIC_RUN_ID"] = TODO
# # System executable override. If provided, the python user script will use the value of PYTHON_EXEC as executable. The sys.executable is used by default.
# os.environ["PYTHON_EXEC"] = TODO
# We're already parallelizing over the context, so we can't also parallelize inside the tokenizers (?)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = f"cuda:{rank}"
torch.cuda.set_device(rank)
input_queue = input_queues[rank]
del input_queues
# Load model once
log.debug(f'inference_worker for rank {rank} creating the model object now')
local_model = InferenceAR(args)
del args
log.debug(f'inference_worker for rank {rank} ready, pushing a "ready" message to the queue')
result_queue.put((rank, "ready"))
# Install interrupt signal handler so that we can shut down gracefully.
should_quit = False
def signal_handler(signum, frame):
nonlocal should_quit
log.info(f"[RANK{rank}] Received signal {signum}, shutting down")
should_quit = True
try:
input_queue.put(None)
except ValueError:
pass
signal.signal(signal.SIGINT, signal_handler)
while not should_quit:
try:
inputs_task = input_queue.get()
except ValueError:
# Queue was closed, we can exit.
log.debug(f"[RANK{rank}] Input queue was closed, exiting.")
break
if inputs_task is None:
# Special sentinel value to indicate that we are done and can exit.
log.debug(f"[RANK{rank}] Got input {inputs_task}, exiting.")
break
# Note: we don't need to chunk the inputs for this rank / process, this is done
# automatically in the model.
# Note: we don't need to move the inputs to a specific device either since the
# Gen3C API expects NumPy arrays.
if False:
log.debug(f"[RANK{rank}] Moving task to {device=}")
inputs_task = move_to_device(inputs_task, device)
# Run the requested task
with torch.no_grad():
task_type, args, kwargs = inputs_task
log.debug(f"[RANK{rank}] Got task: {task_type=}")
if task_type == 'inference':
log.debug(f"[RANK{rank}] Running `inference_on_cameras()`...")
output = local_model.inference_on_cameras(*args, **kwargs)
log.debug(f"[RANK{rank}] Done `inference_on_cameras()`!")
if is_tp_cp_pp_rank0():
log.debug(f"[RANK{rank}] Moving outputs of `inference_on_cameras()` to the CPU")
output = move_to_device(output, device='cpu')
log.debug(f"[RANK{rank}] Pushing outputs of `inference_on_cameras()` to the results queue")
result_queue.put(output)
elif task_type == 'seeding':
log.debug(f"[RANK{rank}] Calling `seed_model_from_values()...`")
if cosmos_variant == 'predict1':
output = local_model.seed_model_from_values(*args, **kwargs)
else:
raise NotImplementedError(f"Unsupported cosmos variant: {cosmos_variant}")
output = move_to_device(output, device='cpu')
result_queue.put((rank, "seed_model_from_values_done", output))
log.debug(f"[RANK{rank}] Done with `seed_model_from_values()`")
elif task_type == 'clear_cache':
log.debug(f"[RANK{rank}] Calling `clear_cache()...`")
local_model.clear_cache()
result_queue.put((rank, "clear_cache_done"))
log.debug(f"[RANK{rank}] Done with `clear_cache()`")
elif task_type == 'get_cache_input_depths':
log.debug(f"[RANK{rank}] Calling `get_cache_input_depths()...`")
input_depths = local_model.get_cache_input_depths()
attrs_queue.put(('cache_input_depths', input_depths.cpu(), True))
log.debug(f"[RANK{rank}] Done with `get_cache_input_depths()`")
elif task_type == 'getattr':
assert kwargs is None
assert len(args) == 1
attr_name = args[0]
assert isinstance(attr_name, str)
has_attr = hasattr(local_model, attr_name)
attr_value_or_none = getattr(local_model, attr_name)
if has_attr and (attr_value_or_none is not None) and torch.is_tensor(attr_value_or_none):
log.debug(f"[RANK{rank}] Attribute {attr_name=} is a torch tensor on "
f"device {attr_value_or_none.device}, cloning it before sending it through the queue")
attr_value_or_none = attr_value_or_none.clone()
log.debug(f"[RANK{rank}] Pushing attribute value for {attr_name=}")
attrs_queue.put((attr_name, attr_value_or_none, has_attr))
else:
raise NotImplementedError(f"Unsupported task type for Cosmos inference worker: {task_type}")
# Cleanup before exiting
local_model.cleanup()
del local_model
def inference_worker(*args, **kwargs):
try:
_inference_worker(*args, **kwargs)
except Exception as e:
import traceback
rank = os.environ.get("LOCAL_RANK", "(unknown)")
log.error(f"[RANK{rank}] encountered exception: {e}. Will re-raise after cleanup."
f" Stack trace:\n{traceback.format_exc()}")
try:
import torch.distributed as dist
dist.destroy_process_group()
log.info(f"[RANK{rank}] Destroyed model parallel group after catching exception."
" Will re-raise now.")
except Exception as _:
pass
raise e
class MultiGPUInferenceAR():
"""
Adapter class to run multi-GPU Cosmos inference in the context of the FastAPI inference server.
This class implements the same interface as `InferenceAR`, but spawns one process per GPU and
forwards inference requests to the multiple processes via a work queue.
The worker processes wait for work from the queue, perform inference, and gather all results
on the rank 0 process. That process then pushes results to the result queue.
"""
def __init__(self, gpu_count: int, cosmos_variant: str, args: argparse.Namespace):
import torch
import torch.multiprocessing as mp
self.gpu_count = gpu_count
assert self.gpu_count <= torch.cuda.device_count(), \
f"Requested {self.gpu_count} GPUs, but only {torch.cuda.device_count()} are available."
ctx = mp.get_context('spawn')
manager = ctx.Manager()
self.input_queues: list[mp.Queue] = [ctx.Queue() for _ in range(self.gpu_count)]
self.result_queue = manager.Queue()
self.attrs_queue = manager.Queue()
log.info(f"Spawning {self.gpu_count} processes (one per GPU)")
self.ctx = mp.spawn(
inference_worker,
args=(args, self.gpu_count, cosmos_variant,
self.input_queues, self.result_queue, self.attrs_queue),
nprocs=self.gpu_count,
join=False
)
log.info(f"Waiting for {self.gpu_count} processes to load the model...")
for _ in range(self.gpu_count):
v = self.result_queue.get()
if not isinstance(v, tuple) or len(v) != 2 or v[1] != "ready":
raise ValueError(f"Expected a 'ready' message from each process, but received: {v}")
log.info(f"Process {v[0]} is ready.")
def inference_on_cameras(self, *args, **kwargs):
log.debug(f"inference_on_cameras(): submitting request to {len(self.input_queues)} inference processes.")
for iq in self.input_queues:
# Send the same input to each process
task = ('inference', args, kwargs)
iq.put(task)
# Wait on the result queue to produce the result (this could take a while).
log.debug(f"inference_on_cameras(): waiting for result...")
outputs = self.result_queue.get()
log.debug(f"inference_on_cameras(): got inference results! Cloning and returning.")
return clone_tensors(outputs)
def seed_model_from_values(self, *args, **kwargs):
log.debug(f"seed_model_from_values(): submitting request to {len(self.input_queues)} inference processes.")
for iq in self.input_queues:
task = ('seeding', args, kwargs)
iq.put(task)
# TODO: refactor this, and maybe use some events or another primitive
log.info(f"Waiting for {self.gpu_count} processes to be done with seeding...")
for i in range(self.gpu_count):
v = self.result_queue.get()
if not isinstance(v, tuple) or len(v) != 3 or v[1] != "seed_model_from_values_done":
raise ValueError(f"Expected a 'seed_model_from_values_done' message from each process, but received: {v}")
log.info(f"Process {v[0]} is done with `seed_model_from_values()`.")
# Arbitrarily pick the output from the first process
if i == 0:
outputs = v[2]
return clone_tensors(outputs)
def clear_cache(self):
for iq in self.input_queues:
task = ('clear_cache', None, None)
iq.put(task)
# TODO: refactor this, and maybe use some events or another primitive
log.info(f"Waiting for {self.gpu_count} processes to be done with clear_cache...")
for _ in range(self.gpu_count):
v = self.result_queue.get()
if not isinstance(v, tuple) or len(v) != 2 or v[1] != "clear_cache_done":
raise ValueError(f"Expected a 'clear_cache_done' message from each process, but received: {v}")
log.info(f"Process {v[0]} is done with `clear_cache()`.")
def get_cache_input_depths(self):
name = 'cache_input_depths'
task = ('get_cache_input_depths', None, None)
self.input_queues[0].put(task)
# TODO: refactor this, and maybe use some events or another primitive
looked_up_name, value, exists = self.attrs_queue.get()
if looked_up_name != name:
# TODO: this could be handled better (retry or enforce some ordering maybe).
raise ValueError(f"Queried model for attribute '{name}' but got attribute '{looked_up_name}',"
" there was likely a race condition.")
log.debug(f"Got a valid response, returning value for `get_cache_input_depths()`")
return value
def __getattr__(self, name: str):
log.debug(f"__getattr__({name=}) called")
# Note: this will not be called for methods we implement here, or attributes
# that actually exist in this object.
# Query the attribute from rank 0 (arbitrarily)
task = ('getattr', (name,), None)
self.input_queues[0].put(task)
# Get result (blocking)
log.debug(f"Waiting for response on `attrs_queue`...")
looked_up_name, value, exists = self.attrs_queue.get()
if looked_up_name != name:
# TODO: this could be handled better (retry or enforce some ordering maybe).
raise ValueError(f"Queried model for attribute '{name}' but got attribute '{looked_up_name}',"
" there was likely a race condition.")
if not exists:
raise AttributeError(f"Model has no attribute named '{name}'")
log.debug(f"Got a valid response, returning {name} == {value}")
return value
def cleanup(self):
"""
Clean up resources before shutting down.
"""
log.info(f"MultiGPUInferenceAR winding down, asking {len(self.input_queues)} processes to clean up.")
# "Close" all queues (there's no actual `close` method in PyTorch MP queues)
for iq in self.input_queues:
iq.put(None)
# Wait for all processes to finish
log.info(f"Waiting for {len(self.input_queues)} processes to finish (join).")
self.ctx.join()
log.info(f"{len(self.input_queues)} processes have finished.")