# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import logging
import socket
import time
from typing import Optional
from urllib.parse import urlparse
import torch
from torch import nn
from ..import_utils import is_requests_available, is_vllm_ascend_available, is_vllm_available
if is_requests_available():
import requests
from requests import ConnectionError
if is_vllm_available():
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
if is_vllm_ascend_available():
from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator
logger = logging.getLogger(__name__)
class VLLMClient:
"""
A client class to interact with a vLLM server.
This class provides methods to generate completions, initialize and manage weight update groups, and update model
weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.
Args:
base_url (`str` or `None`, *optional*, defaults to `None`):
Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are
ignored.
host (`str`, *optional*, defaults to `"0.0.0.0"`):
IP address of the vLLM server. Ignored if `base_url` is provided.
server_port (`int`, *optional*, defaults to `8000`):
Port number of the vLLM server. Ignored if `base_url` is provided.
group_port (`int`, *optional*, defaults to `51216`):
Port number for the weight update group.
connection_timeout (`float`, *optional*, defaults to `0.0`):
Total timeout duration in seconds to wait for the server to be up. If the server is not up after the
timeout, a `ConnectionError` is raised.
Examples:
Run the vLLM server with the model `Qwen/Qwen2.5-7B`:
```
$ trl vllm-serve --model Qwen/Qwen2.5-7B
...
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
```
Use the client to generate completions and update model weights:
```python
>>> from trl.extras.vllm_client import VLLMClient
>>> client = VLLMClient()
>>> client.generate(["Hello, AI!", "Tell me a joke"])
[[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
[911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
>>> client.init_communicator()
>>> client.update_model_params(model)
```
There are several ways to initialize the client:
```python
VLLMClient(base_url="http://localhost:8000")
VLLMClient(base_url="http://192.168.1.100:8000")
VLLMClient(host="localhost", server_port=8000)
VLLMClient(host="192.168.1.100", server_port=8000)
```
"""
def __init__(
self,
base_url: Optional[str] = None,
host: str = "0.0.0.0",
server_port: int = 8000,
group_port: int = 51216,
connection_timeout: float = 0.0,
):
if not is_requests_available():
raise ImportError("requests is not installed. Please install it with `pip install requests`.")
if not is_vllm_available():
raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")
self.session = requests.Session()
if base_url is not None:
# Parse the base_url to extract host and port
parsed_url = urlparse(base_url)
self.host = socket.gethostbyname(parsed_url.hostname)
scheme = parsed_url.scheme or "http"
self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
else:
self.host = host
self.server_port = server_port
self.base_url = f"http://{self.host}:{self.server_port}"
self.group_port = group_port
self.check_server(connection_timeout) # check server and fail after timeout
def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
"""
Check server availability with retries on failure, within a total timeout duration. If the server is not up
after the total timeout duration, raise a `ConnectionError`.
        Args:
            total_timeout (`float`, *optional*, defaults to `0.0`):
                Total timeout duration in seconds.
            retry_interval (`float`, *optional*, defaults to `2.0`):
                Interval in seconds between retries.
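
        Example (illustrative; waits up to two minutes for a locally started `trl vllm-serve` instance):
        ```python
        >>> client.check_server(total_timeout=120.0, retry_interval=2.0)
        ```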
"""
url = f"{self.base_url}/health/"
start_time = time.time() # Record the start time
while True:
try:
response = requests.get(url)
except requests.exceptions.RequestException as exc:
# Check if the total timeout duration has passed
elapsed_time = time.time() - start_time
if elapsed_time >= total_timeout:
raise ConnectionError(
f"The vLLM server can't be reached at {self.base_url} after {total_timeout} seconds. Make "
"sure the server is running by running `trl vllm-serve`."
) from exc
else:
if response.status_code == 200:
if "X-Forwarded-For" in response.headers:
self.host = response.headers["X-Forwarded-For"]
logger.info("Server is up!")
return None
# Retry logic: wait before trying again
logger.info(f"Server is not up yet. Retrying in {retry_interval} seconds...")
time.sleep(retry_interval)
def generate(
self,
prompts: list[str],
n: int = 1,
repetition_penalty: float = 1.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
max_tokens: int = 16,
guided_decoding_regex: Optional[str] = None,
) -> list[list[int]]:
"""
Generates model completions for the provided prompts.
Args:
prompts (`list[str]`):
List of text prompts for which the model will generate completions.
n (`int`, *optional*, defaults to `1`):
Number of completions to generate for each prompt.
repetition_penalty (`float`, *optional*, defaults to `1.0`):
Parameter for repetition penalty. 1.0 means no penalty.
temperature (`float`, *optional*, defaults to `1.0`):
Temperature parameter for sampling. Higher values increase diversity.
top_p (`float`, *optional*, defaults to `1.0`):
                Top-p sampling parameter. `1.0` means no truncation.
top_k (`int`, *optional*, defaults to `-1`):
Top-k sampling parameter. `-1` means no truncation.
min_p (`float`, *optional*, defaults to `0.0`):
Minimum probability for sampling.
max_tokens (`int`, *optional*, defaults to `16`):
Maximum number of tokens to generate for each prompt.
guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
Regular expression to guide the decoding process.
Returns:
`list[list[int]]`:
List of lists of token IDs representing the model-generated completions for each prompt.
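
        Example (illustrative; assumes the server is already running and that the returned token IDs are decoded with
        the matching tokenizer, here assumed to be `Qwen/Qwen2.5-7B`):
        ```python
        >>> from transformers import AutoTokenizer
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")
        >>> completion_ids = client.generate(["The capital of France is"], temperature=0.8, max_tokens=32)
        >>> tokenizer.batch_decode(completion_ids, skip_special_tokens=True)  # token IDs back to text
        ```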
"""
url = f"{self.base_url}/generate/"
response = self.session.post(
url,
json={
"prompts": prompts,
"n": n,
"repetition_penalty": repetition_penalty,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"min_p": min_p,
"max_tokens": max_tokens,
"guided_decoding_regex": guided_decoding_regex,
},
)
if response.status_code == 200:
return response.json()["completion_ids"]
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
def init_communicator(self):
"""
        Initializes the weight update group in a distributed setup for model synchronization. The client joins the
        vLLM server's workers as the last rank of a stateless process group, which is later used to broadcast updated
        weights to the server.
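
        Example (illustrative; shows the typical call order when syncing trainer weights to the server):
        ```python
        >>> client.init_communicator()         # join the server's weight update group
        >>> client.update_model_params(model)  # broadcast the current weights
        >>> client.reset_prefix_cache()        # drop prefixes cached with the old weights
        ```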
"""
# Get the world size from the server
url = f"{self.base_url}/get_world_size/"
response = requests.get(url)
if response.status_code == 200:
vllm_world_size = response.json()["world_size"]
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
world_size = vllm_world_size + 1 # add the client to the world
self.rank = vllm_world_size # the client's rank is the last process
# Initialize weight update group
url = f"{self.base_url}/init_communicator/"
        # On the server side, the host is set to 0.0.0.0
response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
# Brief delay to allow server initialization. While not strictly required (client socket will retry on
# connection failure), this prevents log warnings like:
# [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
time.sleep(0.1)
# Set up the communication group for weight broadcasting
pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size)
self.pynccl_comm = PyNcclCommunicator(pg, device=0)
# When the client object is deleted, close the weight update group
atexit.register(self.close_communicator)
def update_named_param(self, name: str, weights: torch.Tensor):
"""
Updates a specific named parameter in the model and broadcasts it to other processes.
Args:
name (`str`):
Name of the layer whose weights are being updated.
weights (`torch.Tensor`):
Tensor containing the updated weights.
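
        Example (illustrative; assumes `init_communicator` has already been called and `model` is a `transformers`
        model loaded on a CUDA device on the client side):
        ```python
        >>> name, param = next(iter(model.named_parameters()))
        >>> client.update_named_param(name, param.data)
        ```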
"""
dtype, shape = str(weights.dtype), tuple(weights.shape)
url = f"{self.base_url}/update_named_param/"
response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape})
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
# Broadcast the weights to the other processes
self.pynccl_comm.broadcast(weights, src=self.rank)
self.pynccl_comm.group.barrier()
def update_model_params(self, model: nn.Module):
"""
Updates all parameters of the given model by calling `update_named_param` for each parameter in the model.
Args:
model (`nn.Module`):
Model whose parameters (weights/biases) are to be updated.
"""
for name, param in model.named_parameters():
# Update each parameter individually
self.update_named_param(name, param.data)
def reset_prefix_cache(self):
"""
        Resets the prefix cache on the server. This is typically called after updating model weights so that prefixes
        cached with the old weights are not reused.
"""
url = f"{self.base_url}/reset_prefix_cache/"
response = self.session.post(url)
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
def close_communicator(self):
"""
Closes the weight update group and cleans up the communication group.
"""
url = f"{self.base_url}/close_communicator/"
try:
response = self.session.post(url)
except ConnectionError:
# The server might be already down, so we don't need to close the communicator
pass
else:
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
# Example usage
if __name__ == "__main__":
client = VLLMClient()
client.init_communicator()
# Generate completions
    responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32)
print("Responses:", responses) # noqa
# Update model weights
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to("cuda")
client.update_model_params(model)