Spaces:
Sleeping
Sleeping
File size: 7,962 Bytes
33d4721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
import os
import threading
import time
from types import SimpleNamespace
import requests
from autotrain import logger
from autotrain.backends.base import BaseBackend
# Base URL of the NVCF integration API. NVCFRunner.create() submits jobs to
# f"{NVCF_API}/invoke/<hardware id>" and polls f"{NVCF_API}/status/<nvcfRequestId>".
NVCF_API = "https://huggingface.co/api/integrations/dgx/v1"
class NVCFRunner(BaseBackend):
    """
    NVCFRunner is a backend class responsible for managing and executing NVIDIA NVCF jobs.

    ``create()`` submits the job with a POST to ``{NVCF_API}/invoke/<hardware id>`` and then
    polls ``{NVCF_API}/status/<nvcfRequestId>`` from a background thread, printing new
    training-log lines to stdout as they appear.

    Methods
    -------
    _convert_dict_to_object(dictionary):
        Recursively converts a dictionary to an object using SimpleNamespace
        (the input is left unmodified).
    _conf_nvcf(token, nvcf_type, url, job_name, method="POST", payload=None):
        Configures and submits an NVCF job using the specified parameters.
    _poll_nvcf(url, token, job_name, method="get", timeout=86400, interval=30, op="poll"):
        Polls the status of an NVCF job until completion, failure or timeout.
    create():
        Initiates the creation and polling of an NVCF job.
    """

    # Per-request HTTP timeout in seconds, so a stalled connection cannot block a
    # submission or a single poll indefinitely.
    _REQUEST_TIMEOUT = 30

    def _convert_dict_to_object(self, dictionary):
        """
        Recursively convert JSON-like data into attribute-accessible objects.

        Dicts become ``SimpleNamespace`` instances, lists are converted element-wise and
        any other value is returned unchanged. The input is not modified.

        Parameters
        ----------
        dictionary : Any
            A dict, list or scalar (typically the result of ``response.json()``).

        Returns
        -------
        SimpleNamespace | list | Any
            The converted value.
        """
        if isinstance(dictionary, dict):
            # Build a fresh mapping instead of writing converted values back into the
            # caller's dict.
            return SimpleNamespace(
                **{key: self._convert_dict_to_object(value) for key, value in dictionary.items()}
            )
        if isinstance(dictionary, list):
            return [self._convert_dict_to_object(item) for item in dictionary]
        return dictionary

    def _conf_nvcf(self, token, nvcf_type, url, job_name, method="POST", payload=None):
        """
        Send a configuration/submission request to the NVCF API.

        Parameters
        ----------
        token : str
            Bearer token sent in the ``Authorization`` header.
        nvcf_type : str
            Label used only in log and error messages (e.g. ``"job_submit"``).
        url : str
            Endpoint to call.
        job_name : str
            Label prefixed to log messages.
        method : str
            HTTP method; only ``"POST"`` (case-insensitive) is supported.
        payload : dict, optional
            JSON body of the request.

        Returns
        -------
        str | None | SimpleNamespace
            On HTTP 202, the ``nvcfRequestId`` from the response body (``None`` if the key
            is missing or empty). On any other successful status, the JSON body converted
            with ``_convert_dict_to_object``.

        Raises
        ------
        ValueError
            If ``method`` is not POST, or if the response body is not valid JSON.
        Exception
            On an HTTP error status, a timeout, or a connection failure (chained to the
            underlying ``requests`` exception).
        """
        logger.info(f"{job_name}: {method} - Configuring NVCF {nvcf_type}.")
        if method.upper() != "POST":
            raise ValueError(f"Unsupported HTTP method: {method}")
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {token}"}
        try:
            response = requests.post(url, headers=headers, json=payload, timeout=self._REQUEST_TIMEOUT)
            response.raise_for_status()
        except requests.HTTPError as http_err:
            # ``requests.Response.__bool__`` returns ``response.ok``, which is False for every
            # error status, so the response must be compared against None explicitly.
            err_response = http_err.response
            if err_response is not None:
                error_message = err_response.text
                status_code = err_response.status_code
            else:
                error_message = "No additional error information."
                status_code = "unknown"
            # Log the response body for more context
            logger.error(
                f"{job_name}: HTTP error occurred processing NVCF {nvcf_type} with {method} request: {http_err}. "
                f"Error details: {error_message}"
            )
            raise Exception(f"HTTP Error {status_code}: {http_err}. Details: {error_message}") from http_err
        except (requests.Timeout, requests.ConnectionError, ConnectionError) as err:
            # requests' ConnectionError does not derive from the builtin one, so both are listed.
            logger.error(f"{job_name}: Failed to process NVCF {nvcf_type} with {method} request - {repr(err)}")
            raise Exception(f"Unreachable, please try again later: {err}") from err

        if response.status_code == 202:
            # Accepted: the job runs asynchronously and is tracked by its request id.
            logger.info(f"{job_name}: {method} - Successfully submitted NVCF job. Polling reqId for completion")
            nvcf_reqid = response.json().get("nvcfRequestId")
            if nvcf_reqid:
                logger.info(f"{job_name}: nvcfRequestId: {nvcf_reqid}")
                return nvcf_reqid
            logger.warning(f"{job_name}: nvcfRequestId key is missing in the response body")
            return None

        result_obj = self._convert_dict_to_object(response.json())
        logger.info(f"{job_name}: {method} - Successfully processed NVCF {nvcf_type}.")
        return result_obj

    @staticmethod
    def _json_or_none(response):
        """Return the decoded JSON body of ``response``, or None (logged) if it is not valid JSON."""
        try:
            return response.json()
        except ValueError:
            logger.error("Failed to parse JSON from response")
            return None

    @staticmethod
    def _print_nonblank_lines(text):
        """Print every non-blank line of ``text`` (coerced to str) to stdout."""
        for line in str(text).split("\n"):
            if line.strip():
                print(line)

    def _report_training_failure(self, response):
        """Log a training failure and print the non-blank lines of the body's ``detail`` field."""
        logger.error("Training failed")
        data = self._json_or_none(response)
        if isinstance(data, dict) and "detail" in data:
            self._print_nonblank_lines(data["detail"])

    def _print_new_log_lines(self, data, last_full_log):
        """
        Print log text that has not been printed yet and return the full log seen so far.

        ``data["log"]`` is treated as the cumulative log, so only the suffix beyond
        ``last_full_log`` is new. If it no longer starts with ``last_full_log`` (e.g. the log
        was reset or truncated), it is printed in full. A non-dict body or a missing or
        non-string ``log`` value leaves ``last_full_log`` unchanged.
        """
        current_full_log = data.get("log") if isinstance(data, dict) else None
        if not isinstance(current_full_log, str) or current_full_log == last_full_log:
            return last_full_log
        if current_full_log.startswith(last_full_log):
            new_log_content = current_full_log[len(last_full_log) :]
        else:
            new_log_content = current_full_log
        self._print_nonblank_lines(new_log_content)
        return current_full_log

    def _poll_nvcf(self, url, token, job_name, method="get", timeout=86400, interval=30, op="poll"):
        """
        Poll an NVCF status endpoint until the job finishes, fails, or ``timeout`` elapses.

        Each poll issues a GET against ``url`` and waits ``interval`` seconds before the next
        one. New content of the JSON ``"log"`` field is printed line by line.
        - HTTP 200 marks the job as successful. A later 404 then ends polling.
        - HTTP 500 is treated as a training failure: its ``"detail"`` field is printed and
          polling stops.
        - Other HTTP errors, network errors and invalid JSON bodies are logged and retried.

        Parameters
        ----------
        url : str
            Status endpoint to poll.
        token : str
            Bearer token sent in the ``Authorization`` header.
        job_name : str
            Label used in log messages.
        method : str
            HTTP method; only ``"GET"`` (case-insensitive) is supported.
        timeout : float
            Overall polling budget in seconds.
        interval : float
            Delay between polls in seconds.
        op : str
            Operation name used in the timeout error message.

        Raises
        ------
        ValueError
            If ``method`` is not GET.
        TimeoutError
            If neither success nor failure was observed before ``timeout`` elapsed.
        """
        # Validate once up front rather than failing (and retrying) on every iteration.
        if method.upper() != "GET":
            raise ValueError(f"Unsupported HTTP method: {method}")
        timeout = float(timeout)
        interval = float(interval)
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {token}"}
        start_time = time.time()
        success = False
        failed = False
        last_full_log = ""

        while time.time() - start_time < timeout:
            try:
                response = requests.get(url, headers=headers, timeout=self._REQUEST_TIMEOUT)
                if response.status_code == 404 and success:
                    # After a successful status a 404 presumably means the status record was
                    # cleaned up: nothing is left to poll.
                    break
                if response.status_code == 500:
                    # Must be checked before raise_for_status(); otherwise the failure becomes
                    # a generic HTTPError and polling continues until the timeout.
                    self._report_training_failure(response)
                    failed = True
                    break
                response.raise_for_status()
            except requests.HTTPError as http_err:
                logger.error(f"HTTP error occurred: {http_err}")
            except requests.RequestException as err:
                # Connection errors, timeouts, etc.: log and retry on the next poll.
                logger.error(f"Error while handling request: {err}")
            else:
                data = self._json_or_none(response)
                if data is not None and response.status_code in (200, 202):
                    logger.info(
                        f"{job_name}: {method} - {response.status_code} - {'Polling completed' if response.status_code == 200 else 'Polling reqId for completion'}"
                    )
                    last_full_log = self._print_new_log_lines(data, last_full_log)
                    if response.status_code == 200:
                        success = True
            # Always wait between polls, including after errors, to avoid hammering the API.
            time.sleep(interval)

        if failed:
            # The failure has already been reported; a TimeoutError would be misleading.
            return
        if not success:
            raise TimeoutError(f"Operation '{op}' did not complete successfully within the timeout period.")

    def create(self):
        """
        Submit the training job to NVCF and start polling its status in a background thread.

        The submitted command runs ``autotrain.app.training_api:api`` under uvicorn on port
        7860 via ``conda run -p /app/env``, with a copy of ``self.env_vars`` as its
        environment. The polling thread uses a 172800 s (48 h) budget and a 20 s interval.

        Returns
        -------
        str | SimpleNamespace
            The value returned by ``_conf_nvcf`` for the submission — the ``nvcfRequestId``
            when the API answers HTTP 202.

        Raises
        ------
        ValueError
            If the ``SPACE_ID`` environment variable is not set.
        RuntimeError
            If the submission did not return a request id to poll.
        """
        hf_token = self.env_vars["HF_TOKEN"]
        job_name = f"{self.username}-{self.params.project_name}"
        logger.info("Starting NVCF training")
        logger.info(f"job_name: {job_name}")
        logger.info(f"backend: {self.backend}")

        nvcf_url_submit = f"{NVCF_API}/invoke/{self.available_hardware[self.backend]['id']}"
        org_name = os.environ.get("SPACE_ID")
        if org_name is None:
            raise ValueError("SPACE_ID environment variable is not set")
        # Keep only the segment before the first "/" (presumably the Space's owner/org).
        org_name = org_name.split("/")[0]

        nvcf_fr_payload = {
            "cmd": [
                "conda",
                "run",
                "--no-capture-output",
                "-p",
                "/app/env",
                "python",
                "-u",
                "-m",
                "uvicorn",
                "autotrain.app.training_api:api",
                "--host",
                "0.0.0.0",
                "--port",
                "7860",
            ],
            "env": dict(self.env_vars),
            # NOTE(review): ORG_NAME is a top-level payload key, not part of "env" — confirm
            # this is what the invoke endpoint expects.
            "ORG_NAME": org_name,
        }

        nvcf_fn_req = self._conf_nvcf(
            token=hf_token,
            nvcf_type="job_submit",
            url=nvcf_url_submit,
            job_name=job_name,
            method="POST",
            payload=nvcf_fr_payload,
        )
        if not nvcf_fn_req:
            # Without a request id there is nothing to poll; fail fast instead of polling
            # ".../status/None" until the timeout.
            raise RuntimeError(f"{job_name}: NVCF job submission did not return an nvcfRequestId")

        nvcf_url_reqpoll = f"{NVCF_API}/status/{nvcf_fn_req}"
        logger.info(f"{job_name}: Polling : {nvcf_url_reqpoll}")
        poll_thread = threading.Thread(
            target=self._poll_nvcf,
            kwargs={
                "url": nvcf_url_reqpoll,
                "token": hf_token,
                "job_name": job_name,
                "method": "GET",
                "timeout": 172800,
                "interval": 20,
            },
        )
        poll_thread.start()
        return nvcf_fn_req
|