# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import logging
import socket
import time
from typing import Optional
from urllib.parse import urlparse

import torch
from torch import nn

from ..import_utils import is_requests_available, is_vllm_ascend_available, is_vllm_available


if is_requests_available():
    import requests
    from requests import ConnectionError


if is_vllm_available():
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

if is_vllm_ascend_available():
    from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator


logger = logging.getLogger(__name__)


class VLLMClient:
    """
    A client class to interact with a vLLM server.

    This class provides methods to generate completions, initialize and manage weight update groups, and update model
    weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.

    Args:
        base_url (`str` or `None`, *optional*, defaults to `None`):
            Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are
            ignored.
        host (`str`, *optional*, defaults to `"0.0.0.0"`):
            IP address of the vLLM server. Ignored if `base_url` is provided.
        server_port (`int`, *optional*, defaults to `8000`):
            Port number of the vLLM server. Ignored if `base_url` is provided.
        group_port (`int`, *optional*, defaults to `51216`):
            Port number for the weight update group.
        connection_timeout (`float`, *optional*, defaults to `0.0`):
            Total timeout duration in seconds to wait for the server to be up. If the server is not up after the
            timeout, a `ConnectionError` is raised.

    Examples:
        Run the vLLM server with the model `Qwen/Qwen2.5-7B`:

        ```
        $ trl vllm-serve --model Qwen/Qwen2.5-7B
        ...
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
        ```

        Use the client to generate completions and update model weights:

        ```python
        >>> from trl.extras.vllm_client import VLLMClient

        >>> client = VLLMClient()
        >>> client.generate(["Hello, AI!", "Tell me a joke"])
        [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
         [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]

        >>> from transformers import AutoModelForCausalLM

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
        >>> client.init_communicator()
        >>> client.update_model_params(model)
        ```

        There are several ways to initialize the client:

        ```python
        VLLMClient(base_url="http://localhost:8000")
        VLLMClient(base_url="http://192.168.1.100:8000")
        VLLMClient(host="localhost", server_port=8000)
        VLLMClient(host="192.168.1.100", server_port=8000)
        ```
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        host: str = "0.0.0.0",
        server_port: int = 8000,
        group_port: int = 51216,
        connection_timeout: float = 0.0,
    ):
        if not is_requests_available():
            raise ImportError("requests is not installed. Please install it with `pip install requests`.")
        if not is_vllm_available():
            raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")

        self.session = requests.Session()

        if base_url is not None:
            # Parse the base_url to extract host and port
            parsed_url = urlparse(base_url)
            self.host = socket.gethostbyname(parsed_url.hostname)
            scheme = parsed_url.scheme or "http"
            self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
        else:
            self.host = host
            self.server_port = server_port
            self.base_url = f"http://{self.host}:{self.server_port}"
        self.group_port = group_port
        self.check_server(connection_timeout)  # check server and fail after timeout

    def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
        """
        Check server availability with retries on failure, within a total timeout duration. If the server is not up
        after the total timeout duration, raise a `ConnectionError`.

        Args:
            total_timeout (`float`, *optional*, defaults to `0.0`):
                Total timeout duration in seconds.
            retry_interval (`float`, *optional*, defaults to `2.0`):
                Interval in seconds between retries.
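
        Example (a minimal sketch; `__init__` already calls this method, shown here only as a standalone health
        check):

        ```python
        >>> client.check_server(total_timeout=30.0, retry_interval=1.0)  # retry for up to 30s before giving up
        ```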
""" | |
url = f"{self.base_url}/health/" | |
start_time = time.time() # Record the start time | |
while True: | |
try: | |
response = requests.get(url) | |
except requests.exceptions.RequestException as exc: | |
# Check if the total timeout duration has passed | |
elapsed_time = time.time() - start_time | |
if elapsed_time >= total_timeout: | |
raise ConnectionError( | |
f"The vLLM server can't be reached at {self.base_url} after {total_timeout} seconds. Make " | |
"sure the server is running by running `trl vllm-serve`." | |
) from exc | |
else: | |
if response.status_code == 200: | |
if "X-Forwarded-For" in response.headers: | |
self.host = response.headers["X-Forwarded-For"] | |
logger.info("Server is up!") | |
return None | |
# Retry logic: wait before trying again | |
logger.info(f"Server is not up yet. Retrying in {retry_interval} seconds...") | |
time.sleep(retry_interval) | |

    def generate(
        self,
        prompts: list[str],
        n: int = 1,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        max_tokens: int = 16,
        guided_decoding_regex: Optional[str] = None,
    ) -> list[list[int]]:
        """
        Generates model completions for the provided prompts.

        Args:
            prompts (`list[str]`):
                List of text prompts for which the model will generate completions.
            n (`int`, *optional*, defaults to `1`):
                Number of completions to generate for each prompt.
            repetition_penalty (`float`, *optional*, defaults to `1.0`):
                Parameter for repetition penalty. `1.0` means no penalty.
            temperature (`float`, *optional*, defaults to `1.0`):
                Temperature parameter for sampling. Higher values increase diversity.
            top_p (`float`, *optional*, defaults to `1.0`):
                Top-p sampling parameter. `1.0` means no truncation.
            top_k (`int`, *optional*, defaults to `-1`):
                Top-k sampling parameter. `-1` means no truncation.
            min_p (`float`, *optional*, defaults to `0.0`):
                Minimum probability for sampling.
            max_tokens (`int`, *optional*, defaults to `16`):
                Maximum number of tokens to generate for each prompt.
            guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
                Regular expression to guide the decoding process.

        Returns:
            `list[list[int]]`:
                List of lists of token IDs representing the model-generated completions for each prompt.
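
        Example (prompt and sampling values are illustrative; assumes a server is reachable at the default address):

        ```python
        >>> client = VLLMClient()
        >>> # returns two lists of token IDs, one per sampled completion
        >>> client.generate(["The capital of France is"], n=2, temperature=0.8, max_tokens=8)
        ```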
""" | |
url = f"{self.base_url}/generate/" | |
response = self.session.post( | |
url, | |
json={ | |
"prompts": prompts, | |
"n": n, | |
"repetition_penalty": repetition_penalty, | |
"temperature": temperature, | |
"top_p": top_p, | |
"top_k": top_k, | |
"min_p": min_p, | |
"max_tokens": max_tokens, | |
"guided_decoding_regex": guided_decoding_regex, | |
}, | |
) | |
if response.status_code == 200: | |
return response.json()["completion_ids"] | |
else: | |
raise Exception(f"Request failed: {response.status_code}, {response.text}") | |

    def init_communicator(self):
        """
        Initializes the weight update group in a distributed setup for model synchronization.
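
        Example (a minimal sketch; assumes a server is running and the client process has a CUDA device):

        ```python
        >>> client = VLLMClient()
        >>> client.init_communicator()  # must be called once before any weight updates
        ```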
""" | |
# Get the world size from the server | |
url = f"{self.base_url}/get_world_size/" | |
response = requests.get(url) | |
if response.status_code == 200: | |
vllm_world_size = response.json()["world_size"] | |
else: | |
raise Exception(f"Request failed: {response.status_code}, {response.text}") | |
world_size = vllm_world_size + 1 # add the client to the world | |
self.rank = vllm_world_size # the client's rank is the last process | |
# Initialize weight update group | |
url = f"{self.base_url}/init_communicator/" | |
# In the server side, the host is set to 0.0.0.0 | |
response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size}) | |
if response.status_code != 200: | |
raise Exception(f"Request failed: {response.status_code}, {response.text}") | |
# Brief delay to allow server initialization. While not strictly required (client socket will retry on | |
# connection failure), this prevents log warnings like: | |
# [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3 | |
time.sleep(0.1) | |
# Set up the communication group for weight broadcasting | |
pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size) | |
self.pynccl_comm = PyNcclCommunicator(pg, device=0) | |
# When the client object is deleted, close the weight update group | |
atexit.register(self.close_communicator) | |

    def update_named_param(self, name: str, weights: torch.Tensor):
        """
        Updates a specific named parameter in the model and broadcasts it to other processes.

        Args:
            name (`str`):
                Name of the layer whose weights are being updated.
            weights (`torch.Tensor`):
                Tensor containing the updated weights.
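
        Example (a sketch only; the parameter name and shape are illustrative, and `init_communicator()` must have
        been called first):

        ```python
        >>> client.update_named_param("model.embed_tokens.weight", torch.randn(8, 8, device="cuda"))
        ```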
""" | |
dtype, shape = str(weights.dtype), tuple(weights.shape) | |
url = f"{self.base_url}/update_named_param/" | |
response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape}) | |
if response.status_code != 200: | |
raise Exception(f"Request failed: {response.status_code}, {response.text}") | |
# Broadcast the weights to the other processes | |
self.pynccl_comm.broadcast(weights, src=self.rank) | |
self.pynccl_comm.group.barrier() | |

    def update_model_params(self, model: nn.Module):
        """
        Updates all parameters of the given model by calling `update_named_param` for each parameter in the model.

        Args:
            model (`nn.Module`):
                Model whose parameters (weights/biases) are to be updated.
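
        Example (sketch; assumes `init_communicator()` has been called and the model fits on the local device):

        ```python
        >>> from transformers import AutoModelForCausalLM
        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
        >>> client.update_model_params(model)
        ```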
""" | |
for name, param in model.named_parameters(): | |
# Update each parameter individually | |
self.update_named_param(name, param.data) | |

    def reset_prefix_cache(self):
        """
        Resets the prefix cache for the model.
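
        Example (illustrative; typically called after a weight update so prefixes cached with the old weights are
        not reused):

        ```python
        >>> client.reset_prefix_cache()
        ```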
""" | |
url = f"{self.base_url}/reset_prefix_cache/" | |
response = self.session.post(url) | |
if response.status_code != 200: | |
raise Exception(f"Request failed: {response.status_code}, {response.text}") | |

    def close_communicator(self):
        """
        Closes the weight update group and cleans up the communication group.
        """
        url = f"{self.base_url}/close_communicator/"

        try:
            response = self.session.post(url)
        except ConnectionError:
            # The server might already be down, in which case there is no communicator left to close
            pass
        else:
            if response.status_code != 200:
                raise Exception(f"Request failed: {response.status_code}, {response.text}")


# Example usage
if __name__ == "__main__":
    client = VLLMClient()
    client.init_communicator()

    # Generate completions
    responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32)
    print("Responses:", responses)  # noqa

    # Update model weights
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to("cuda")
    client.update_model_params(model)