# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from collections.abc import Sequence
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from itertools import chain
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from typing import Optional

import torch
from trl import TrlParser
from trl.import_utils import (
    is_fastapi_available,
    is_pydantic_available,
    is_uvicorn_available,
    is_vllm_ascend_available,
    is_vllm_available,
)


if is_fastapi_available():
    from fastapi import FastAPI


if is_pydantic_available():
    from pydantic import BaseModel


if is_uvicorn_available():
    import uvicorn


if is_vllm_available():
    from vllm import LLM, SamplingParams
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.parallel_state import get_world_group
    from vllm.distributed.utils import StatelessProcessGroup
    from vllm.sampling_params import GuidedDecodingParams
    from vllm.utils import get_open_port

if is_vllm_ascend_available():
    from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator


logger = logging.getLogger(__name__)

# We use CUDA with multiprocessing, so we must use the 'spawn' start method. Otherwise, we will get the following
# error: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use
# the 'spawn' start method
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


class WeightSyncWorkerExtension:
    """
    A vLLM worker extension that enables weight synchronization between a client and multiple server workers.

    This worker uses a `StatelessProcessGroup` to establish communication and a `PyNcclCommunicator` to handle
    efficient GPU-based communication using NCCL. The primary purpose of this class is to receive updated model
    weights from a client process and distribute them to all worker processes participating in model inference.
    """

    # The following attributes are initialized when the `init_communicator` method is called.
    pynccl_comm = None  # Communicator for weight updates
    client_rank = None  # Source rank for broadcasting updated weights
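
    # Typical flow (as driven by the HTTP server defined in `main` below): the client first triggers
    # `init_communicator` on every worker through `collective_rpc`, then, for each updated tensor, it calls
    # `update_named_param` while broadcasting the data over NCCL, and finally calls `close_communicator`
    # once training-side weight pushes are done.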

    def init_communicator(self, host: str, port: int, world_size: int) -> None:
        """
        Initializes the weight update communicator using a stateless process group.

        This method creates a `StatelessProcessGroup` that allows external training processes to communicate with
        vLLM workers without interfering with the global torch distributed group.

        Args:
            host (`str`):
                Hostname or IP address of the master node.
            port (`int`):
                Port number to be used for communication.
            world_size (`int`):
                Total number of participating processes in the update group.
        """
        if self.pynccl_comm is not None:
            raise RuntimeError("Weight update group already initialized. Call close_communicator first.")

        # Get the rank of the current worker in the global world group.
        rank = get_world_group().rank

        # Create a stateless process group to manage communication between training processes and vLLM workers.
        pg = StatelessProcessGroup.create(host=host, port=port, rank=rank, world_size=world_size)

        # Initialize the NCCL-based communicator for weight synchronization.
        self.pynccl_comm = PyNcclCommunicator(pg, device=self.device)

        # The client process that sends updated weights has the highest rank (world_size - 1).
        self.client_rank = world_size - 1

    def update_named_param(self, name: str, dtype: torch.dtype, shape: Sequence[int]) -> None:
        """
        Receives updated weights from the client process and updates the named parameter in the model.

        Args:
            name (`str`):
                Name of the weight tensor being updated.
            dtype (`torch.dtype`):
                Data type of the weight tensor (e.g., `torch.float32`).
            shape (`Sequence[int]`):
                Shape of the weight tensor.
        """
        if self.pynccl_comm is None:
            raise RuntimeError("Communicator not initialized. Call `init_communicator` first.")

        # Allocate memory for the incoming weight tensor on the correct device.
        weight = torch.empty(shape, dtype=dtype, device=self.device)

        # Use NCCL to broadcast the updated weights from the client (src) to all workers.
        self.pynccl_comm.broadcast(weight, src=self.client_rank)
        self.pynccl_comm.group.barrier()

        # Load the received weights into the model.
        self.model_runner.model.load_weights(weights=[(name, weight)])

    def close_communicator(self) -> None:
        """
        Closes the communicator when weight synchronization is no longer needed.

        This method deletes the NCCL communicator to release associated resources.
        """
        if self.pynccl_comm is not None:
            del self.pynccl_comm
            self.pynccl_comm = None  # Ensure attribute is reset to None
            self.client_rank = None  # Ensure attribute is reset to None
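
# A minimal client-side sketch of the weight-sync protocol (assumptions: the `requests` library is used for the HTTP
# calls; `host`, `port`, `group_port`, `tensor_parallel_size`, `data_parallel_size`, and `model` are placeholders; the
# server exposes the endpoints registered in `main` below). This is an illustrative sketch, not TRL's actual client:
#
#     import requests
#     from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
#     from vllm.distributed.utils import StatelessProcessGroup
#
#     world_size = tensor_parallel_size * data_parallel_size + 1  # all vLLM workers + this client
#     requests.post(f"http://{host}:{port}/init_communicator/",
#                   json={"host": host, "port": group_port, "world_size": world_size})
#     pg = StatelessProcessGroup.create(host=host, port=group_port, rank=world_size - 1, world_size=world_size)
#     comm = PyNcclCommunicator(pg, device="cuda:0")
#     for name, param in model.named_parameters():
#         requests.post(f"http://{host}:{port}/update_named_param/",
#                       json={"name": name, "dtype": str(param.dtype), "shape": list(param.shape)})
#         comm.broadcast(param.data, src=world_size - 1)  # the client is the highest rank (see `client_rank` above)
#         comm.group.barrier()
#     requests.post(f"http://{host}:{port}/close_communicator/")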


@dataclass
class ScriptArguments:
    r"""
    Arguments for the script.

    Args:
        model (`str`):
            Model name or path to load the model from.
        revision (`str` or `None`, *optional*, defaults to `None`):
            Revision to use for the model. If not specified, the default branch will be used.
        tensor_parallel_size (`int`, *optional*, defaults to `1`):
            Number of tensor parallel workers to use.
        data_parallel_size (`int`, *optional*, defaults to `1`):
            Number of data parallel workers to use.
        host (`str`, *optional*, defaults to `"0.0.0.0"`):
            Host address to run the server on.
        port (`int`, *optional*, defaults to `8000`):
            Port to run the server on.
        gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
            Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
            device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
            improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM)
            errors during initialization.
        dtype (`str`, *optional*, defaults to `"auto"`):
            Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
            based on the model configuration. Find the supported values in the vLLM documentation.
        max_model_len (`int` or `None`, *optional*, defaults to `None`):
            If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced
            `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
            context size, which might be much larger than the KV cache, leading to inefficiencies.
        enable_prefix_caching (`bool` or `None`, *optional*, defaults to `None`):
            Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware
            support this feature.
        enforce_eager (`bool` or `None`, *optional*, defaults to `None`):
            Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the
            model in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid.
        kv_cache_dtype (`str`, *optional*, defaults to `"auto"`):
            Data type to use for KV cache. If set to `"auto"`, the dtype will default to the model data type.
        log_level (`str`, *optional*, defaults to `"info"`):
            Log level for uvicorn. Possible choices: `"critical"`, `"error"`, `"warning"`, `"info"`, `"debug"`,
            `"trace"`.
    """
    model: str = field(
        metadata={"help": "Model name or path to load the model from."},
    )
    revision: Optional[str] = field(
        default=None,
        metadata={"help": "Revision to use for the model. If not specified, the default branch will be used."},
    )
    tensor_parallel_size: int = field(
        default=1,
        metadata={"help": "Number of tensor parallel workers to use."},
    )
    data_parallel_size: int = field(
        default=1,
        metadata={"help": "Number of data parallel workers to use."},
    )
    host: str = field(
        default="0.0.0.0",
        metadata={"help": "Host address to run the server on."},
    )
    port: int = field(
        default=8000,
        metadata={"help": "Port to run the server on."},
    )
    gpu_memory_utilization: float = field(
        default=0.9,
        metadata={
            "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
            "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
            "size and thus improve the model's throughput. However, if the value is too high, it may cause "
            "out-of-memory (OOM) errors during initialization."
        },
    )
    dtype: str = field(
        default="auto",
        metadata={
            "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
            "determined based on the model configuration. Find the supported values in the vLLM documentation."
        },
    )
    max_model_len: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
            "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
            "context size, which might be much larger than the KV cache, leading to inefficiencies."
        },
    )
    enable_prefix_caching: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
            "hardware support this feature."
        },
    )
    enforce_eager: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always "
            "execute the model in eager mode. If `False` (default behavior), we will use CUDA graph and eager "
            "execution in hybrid."
        },
    )
    kv_cache_dtype: str = field(
        default="auto",
        metadata={
            "help": "Data type to use for KV cache. If set to 'auto', the dtype will default to the model data type."
        },
    )
    log_level: str = field(
        default="info",
        metadata={
            "help": "Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', "
            "'trace'."
        },
    )
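
# Example launch (illustrative model name and parallelism): this script is exposed as the `vllm-serve` subcommand of
# the TRL CLI (see `make_parser` below), e.g.
#
#     trl vllm-serve --model Qwen/Qwen2.5-7B --tensor_parallel_size 2 --data_parallel_size 2
#
# which parses `ScriptArguments` above and then runs `main` below.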


def llm_worker(
    script_args: ScriptArguments, data_parallel_rank: int, master_port: int, connection: Connection
) -> None:
    # Set required environment variables for DP to work with vLLM
    os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
    os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
    os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
    os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)

    llm = LLM(
        model=script_args.model,
        revision=script_args.revision,
        tensor_parallel_size=script_args.tensor_parallel_size,
        gpu_memory_utilization=script_args.gpu_memory_utilization,
        enforce_eager=script_args.enforce_eager,
        dtype=script_args.dtype,
        # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
        # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
        # This is particularly useful here because we generate completions from the same prompts.
        enable_prefix_caching=script_args.enable_prefix_caching,
        kv_cache_dtype=script_args.kv_cache_dtype,
        max_model_len=script_args.max_model_len,
        worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
    )

    # Send ready signal to parent process
    connection.send({"status": "ready"})

    while True:
        # Wait for commands from the parent process
        try:
            command = connection.recv()
        except KeyboardInterrupt:
            llm.collective_rpc(method="close_communicator")
            break

        # Handle commands
        if command["type"] in ["call", "fire_and_forget"]:
            method_name = command["method"]
            args, kwargs = command.get("args", ()), command.get("kwargs", {})
            method = getattr(llm, method_name)
            result = method(*args, **kwargs)
            if command["type"] == "call":
                connection.send(result)
        elif command["type"] == "shutdown":
            break
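
# Messages exchanged over the `Pipe` between `main` and each `llm_worker` are plain dicts:
#   {"type": "call", "method": ..., "args": ..., "kwargs": ...}            -> runs the method and sends the result back
#   {"type": "fire_and_forget", "method": ..., "args": ..., "kwargs": ...} -> runs the method, no reply
#   {"type": "shutdown"}                                                   -> exits the worker loop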


def chunk_list(lst: list, n: int) -> list[list]:
    """
    Split list `lst` into `n` evenly distributed sublists.

    Example:
        >>> chunk_list([1, 2, 3, 4, 5, 6], 2)
        [[1, 2, 3], [4, 5, 6]]
        >>> chunk_list([1, 2, 3, 4, 5, 6], 4)
        [[1, 2], [3, 4], [5], [6]]
        >>> chunk_list([1, 2, 3, 4, 5, 6], 8)
        [[1], [2], [3], [4], [5], [6], [], []]
    """
    k, r = divmod(len(lst), n)
    return [lst[i * k + min(i, r) : (i + 1) * k + min(i + 1, r)] for i in range(n)]


def main(script_args: ScriptArguments):
    if not is_fastapi_available():
        raise ImportError(
            "FastAPI is required to run the vLLM serve script. Please install it using `pip install fastapi`."
        )

    if not is_pydantic_available():
        raise ImportError(
            "Pydantic is required to run the vLLM serve script. Please install it using `pip install pydantic`."
        )

    if not is_uvicorn_available():
        raise ImportError(
            "Uvicorn is required to run the vLLM serve script. Please install it using `pip install uvicorn`."
        )

    if not is_vllm_available():
        raise ImportError("vLLM is required to run the vLLM serve script. Please install it using `pip install vllm`.")

    # Spawn DP workers, and set up pipes for communication
    master_port = get_open_port()
    connections = []
    processes = []
    for data_parallel_rank in range(script_args.data_parallel_size):
        parent_connection, child_connection = Pipe()
        process = Process(target=llm_worker, args=(script_args, data_parallel_rank, master_port, child_connection))
        process.start()
        connections.append(parent_connection)
        processes.append(process)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Wait for all workers to send "ready"
        ready_connections = set()
        while len(ready_connections) < script_args.data_parallel_size:
            for connection in connections:
                msg = connection.recv()
                if isinstance(msg, dict) and msg.get("status") == "ready":
                    ready_connections.add(connection)

        yield

        # Wait for processes to terminate
        for process in processes:
            process.join(timeout=10)  # Wait for 10 seconds for the process to terminate
            if process.is_alive():
                logger.warning(f"Process {process} is still alive after 10 seconds, attempting to terminate...")
                process.terminate()
                process.join()  # ensure process termination after calling terminate()

    app = FastAPI(lifespan=lifespan)

    # Define the endpoints for the model server
    @app.get("/health/")
    async def health():
        """
        Health check endpoint to verify that the server is running.
        """
        return {"status": "ok"}

    @app.get("/get_world_size/")
    async def get_world_size():
        """
        Retrieves the world size of the LLM engine, which is `tensor_parallel_size * data_parallel_size`.

        Returns:
            `dict`:
                A dictionary containing the world size.

        Example response:
        ```json
        {"world_size": 8}
        ```
        """
        return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size}

    class GenerateRequest(BaseModel):
        prompts: list[str]
        n: int = 1
        repetition_penalty: float = 1.0
        temperature: float = 1.0
        top_p: float = 1.0
        top_k: int = -1
        min_p: float = 0.0
        max_tokens: int = 16
        guided_decoding_regex: Optional[str] = None

    class GenerateResponse(BaseModel):
        completion_ids: list[list[int]]

    @app.post("/generate/", response_model=GenerateResponse)
    async def generate(request: GenerateRequest):
        """
        Generates completions for the provided prompts.

        Args:
            request (`GenerateRequest`):
                - `prompts` (list of `str`): A list of prompts (text strings) for the model to generate completions.

        Returns:
            `GenerateResponse`:
                - `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated
                  completion.

        Example request:
        ```json
        {"prompts": ["Hello world", "What is AI?"]}
        ```

        Example response:
        ```json
        {"completion_ids": [[101, 102, 103], [201, 202, 203]]}
        ```
        """
        # Guided decoding, if enabled
        if request.guided_decoding_regex is not None:
            guided_decoding = GuidedDecodingParams(backend="outlines", regex=request.guided_decoding_regex)
        else:
            guided_decoding = None

        # Sampling parameters
        sampling_params = SamplingParams(
            n=request.n,
            repetition_penalty=request.repetition_penalty,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            min_p=request.min_p,
            max_tokens=request.max_tokens,
            guided_decoding=guided_decoding,
        )

        # Evenly distribute prompts across DP ranks
        chunked_prompts = chunk_list(request.prompts, script_args.data_parallel_size)

        # Send the prompts to each worker
        for connection, prompts in zip(connections, chunked_prompts):
            # When the number of prompts is less than data_parallel_size, some workers will receive empty prompts.
            # However, vLLM requires that we always send at least one prompt. So we send a placeholder prompt to
            # comply with vLLM's requirement, and we later ignore the result.
            if not prompts:
                prompts = ["<placeholder>"]
            kwargs = {"prompts": prompts, "sampling_params": sampling_params}
            connection.send({"type": "call", "method": "generate", "kwargs": kwargs})

        # Receive results
        all_outputs = [connection.recv() for connection in connections]

        # Handle empty prompts (see above)
        all_outputs = [output for output, prompts in zip(all_outputs, chunked_prompts) if prompts]

        # Flatten and combine all results
        all_outputs = list(chain.from_iterable(all_outputs))  # from list of lists to a single list
        completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
        return {"completion_ids": completion_ids}

    class InitCommunicatorRequest(BaseModel):
        host: str
        port: int
        world_size: int

    @app.post("/init_communicator/")
    async def init_communicator(request: InitCommunicatorRequest):
        """
        Initializes the communicator for synchronizing model weights between a client and multiple server workers.

        Args:
            request (`InitCommunicatorRequest`):
                - `host` (`str`): Hostname or IP address of the master node.
                - `port` (`int`): Port number to be used for communication.
                - `world_size` (`int`): Total number of participating processes in the group.
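
        Example request (illustrative values; e.g. `tensor_parallel_size=4` and `data_parallel_size=2` plus the
        client process gives `world_size = 9`):
        ```json
        {"host": "0.0.0.0", "port": 51216, "world_size": 9}
        ```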
""" | |
world_size = script_args.tensor_parallel_size * script_args.data_parallel_size + 1 | |
# The function init_communicator is called this way: init_communicator(host, port, world_size) | |
# So with collective_rpc we need to call it this way: | |
# llm.collective_rpc(method="init_communicator", args=(host, port, world_size)) | |
kwargs = {"method": "init_communicator", "args": (request.host, request.port, world_size)} | |
for connection in connections: | |
connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}) | |
return {"message": "Request received, initializing communicator"} | |

    class UpdateWeightsRequest(BaseModel):
        name: str
        dtype: str
        shape: list[int]

    @app.post("/update_named_param/")
    async def update_named_param(request: UpdateWeightsRequest):
        """
        Updates the model weights with the provided tensor.

        Once this endpoint is called, the client process should broadcast the updated weights to all server workers.

        Args:
            request (`UpdateWeightsRequest`):
                - `name` (`str`): Name of the weight tensor being updated.
                - `dtype` (`str`): Data type of the weight tensor (e.g., `"torch.float32"`).
                - `shape` (list of `int`): Shape of the weight tensor.
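
        Example request (illustrative parameter name, dtype, and shape):
        ```json
        {"name": "model.embed_tokens.weight", "dtype": "torch.bfloat16", "shape": [151936, 4096]}
        ```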
""" | |
# The function update_named_param is called this way: update_named_param("name", torch.float32, (10, 10)) | |
# So with collective_rpc we need to call it this way: | |
# llm.collective_rpc("update_named_param", args=("name", torch.float32, (10, 10))) | |
dtype = torch.__getattribute__(request.dtype.split(".")[-1]) | |
kwargs = {"method": "update_named_param", "args": (request.name, dtype, tuple(request.shape))} | |
for connection in connections: | |
connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}) | |
return {"message": "Request received, updating named parameter"} | |

    @app.post("/reset_prefix_cache/")
    async def reset_prefix_cache():
        """
        Resets the prefix cache for the model.
        """
        for connection in connections:
            connection.send({"type": "call", "method": "reset_prefix_cache"})

        # Wait for and collect all results
        all_outputs = [connection.recv() for connection in connections]
        success = all(output for output in all_outputs)
        return {"message": "Request received, resetting prefix cache status: " + str(success)}

    @app.post("/close_communicator/")
    async def close_communicator():
        """
        Closes the weight update group and cleans up associated resources.
        """
        kwargs = {"method": "close_communicator"}
        for connection in connections:
            connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs})
        return {"message": "Request received, closing communicator"}

    # Start the server
    uvicorn.run(app, host=script_args.host, port=script_args.port, log_level=script_args.log_level)


def make_parser(subparsers: argparse._SubParsersAction = None):
    if subparsers is not None:
        parser = subparsers.add_parser("vllm-serve", help="Run the vLLM serve script", dataclass_types=ScriptArguments)
    else:
        parser = TrlParser(ScriptArguments)
    return parser


if __name__ == "__main__":
    parser = make_parser()
    (script_args,) = parser.parse_args_and_config()
    main(script_args)
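
# Quick smoke test (illustrative; assumes the default host/port and the endpoints registered in `main` above):
#
#     curl http://localhost:8000/health/
#     curl -X POST http://localhost:8000/generate/ \
#          -H "Content-Type: application/json" \
#          -d '{"prompts": ["Hello world"], "max_tokens": 16}'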