import os
from dataclasses import dataclass
from typing import Optional
import requests
from autotrain import logger
AUTOTRAIN_API = os.environ.get("AUTOTRAIN_API", "https://autotrain-projects-autotrain-advanced.hf.space/")
BACKENDS = {
"spaces-a10g-large": "a10g-large",
"spaces-a10g-small": "a10g-small",
"spaces-a100-large": "a100-large",
"spaces-t4-medium": "t4-medium",
"spaces-t4-small": "t4-small",
"spaces-cpu-upgrade": "cpu-upgrade",
"spaces-cpu-basic": "cpu-basic",
"spaces-l4x1": "l4x1",
"spaces-l4x4": "l4x4",
"spaces-l40sx1": "l40sx1",
"spaces-l40sx4": "l40sx4",
"spaces-l40sx8": "l40sx8",
"spaces-a10g-largex2": "a10g-largex2",
"spaces-a10g-largex4": "a10g-largex4",
}
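# The keys of BACKENDS are the values accepted by Client.create(backend=...); the
# values are the corresponding Hugging Face Spaces hardware flavors. A minimal
# illustration (the choice of flavor here is arbitrary):
#
#   client.create(..., backend="spaces-a10g-large", ...)  # trains on an a10g-large Space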
PARAMS = {}
PARAMS["llm"] = {
"target_modules": "all-linear",
"log": "tensorboard",
"mixed_precision": "fp16",
"quantization": "int4",
"peft": True,
"block_size": 1024,
"epochs": 3,
"padding": "right",
"chat_template": "none",
"max_completion_length": 128,
"distributed_backend": "ddp",
"scheduler": "linear",
"merge_adapter": True,
}
PARAMS["text-classification"] = {
"mixed_precision": "fp16",
"log": "tensorboard",
}
PARAMS["st"] = {
"mixed_precision": "fp16",
"log": "tensorboard",
}
PARAMS["image-classification"] = {
"mixed_precision": "fp16",
"log": "tensorboard",
}
PARAMS["image-object-detection"] = {
"mixed_precision": "fp16",
"log": "tensorboard",
}
PARAMS["seq2seq"] = {
"mixed_precision": "fp16",
"target_modules": "all-linear",
"log": "tensorboard",
}
PARAMS["tabular"] = {
"categorical_imputer": "most_frequent",
"numerical_imputer": "median",
"numeric_scaler": "robust",
}
PARAMS["token-classification"] = {
"mixed_precision": "fp16",
"log": "tensorboard",
}
PARAMS["text-regression"] = {
"mixed_precision": "fp16",
"log": "tensorboard",
}
PARAMS["image-regression"] = {
"mixed_precision": "fp16",
"log": "tensorboard",
}
PARAMS["vlm"] = {
"mixed_precision": "fp16",
"target_modules": "all-linear",
"log": "tensorboard",
"quantization": "int4",
"peft": True,
"epochs": 3,
}
PARAMS["extractive-qa"] = {
"mixed_precision": "fp16",
"log": "tensorboard",
"max_seq_length": 512,
"max_doc_stride": 128,
}
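# PARAMS holds per-task training defaults. Client.create() fills these in for any
# keys the caller does not pass explicitly; caller-supplied values take precedence.
# A minimal sketch (the override values below are illustrative only):
#
#   client.create(..., task="llm:sft", params={"epochs": 1, "block_size": 2048}, ...)
#   # effective params: PARAMS["llm"] with epochs=1 and block_size=2048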
DEFAULT_COLUMN_MAPPING = {}
DEFAULT_COLUMN_MAPPING["llm:sft"] = {"text_column": "text"}
DEFAULT_COLUMN_MAPPING["llm:generic"] = {"text_column": "text"}
DEFAULT_COLUMN_MAPPING["llm:default"] = {"text_column": "text"}
DEFAULT_COLUMN_MAPPING["llm:dpo"] = {
"prompt_column": "prompt",
"text_column": "chosen",
"rejected_text_column": "rejected",
}
DEFAULT_COLUMN_MAPPING["llm:orpo"] = {
"prompt_column": "prompt",
"text_column": "chosen",
"rejected_text_column": "rejected",
}
DEFAULT_COLUMN_MAPPING["llm:reward"] = {"text_column": "chosen", "rejected_text_column": "rejected"}
DEFAULT_COLUMN_MAPPING["vlm:captioning"] = {"image_column": "image", "text_column": "caption"}
DEFAULT_COLUMN_MAPPING["vlm:vqa"] = {
"image_column": "image",
"prompt_text_column": "question",
"text_column": "answer",
}
DEFAULT_COLUMN_MAPPING["st:pair"] = {"sentence1": "anchor", "sentence2": "positive"}
DEFAULT_COLUMN_MAPPING["st:pair_class"] = {
"sentence1_column": "premise",
"sentence2_column": "hypothesis",
"target_column": "label",
}
DEFAULT_COLUMN_MAPPING["st:pair_score"] = {
"sentence1_column": "sentence1",
"sentence2_column": "sentence2",
"target_column": "score",
}
DEFAULT_COLUMN_MAPPING["st:triplet"] = {
"sentence1_column": "anchor",
"sentence2_column": "positive",
"sentence3_column": "negative",
}
DEFAULT_COLUMN_MAPPING["st:qa"] = {"sentence1_column": "query", "sentence2_column": "answer"}
DEFAULT_COLUMN_MAPPING["text-classification"] = {"text_column": "text", "target_column": "target"}
DEFAULT_COLUMN_MAPPING["seq2seq"] = {"text_column": "text", "target_column": "target"}
DEFAULT_COLUMN_MAPPING["text-regression"] = {"text_column": "text", "target_column": "target"}
DEFAULT_COLUMN_MAPPING["token-classification"] = {"text_column": "tokens", "target_column": "tags"}
DEFAULT_COLUMN_MAPPING["image-classification"] = {"image_column": "image", "target_column": "label"}
DEFAULT_COLUMN_MAPPING["image-regression"] = {"image_column": "image", "target_column": "target"}
DEFAULT_COLUMN_MAPPING["image-object-detection"] = {"image_column": "image", "objects_column": "objects"}
DEFAULT_COLUMN_MAPPING["tabular:classification"] = {"id_column": "id", "target__columns": ["target"]}
DEFAULT_COLUMN_MAPPING["tabular:regression"] = {"id_column": "id", "target_columns": ["target"]}
DEFAULT_COLUMN_MAPPING["extractive-qa"] = {
"text_column": "context",
"question_column": "question",
"answer_column": "answers",
}
VALID_TASKS = list(DEFAULT_COLUMN_MAPPING.keys())
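# DEFAULT_COLUMN_MAPPING describes which dataset columns each task expects. A
# caller-supplied column_mapping must contain the same keys; only the values (the
# column names in the caller's dataset) may differ. A minimal sketch, assuming a
# hypothetical dataset whose text lives in "review" and labels in "sentiment":
#
#   column_mapping = {"text_column": "review", "target_column": "sentiment"}
#   client.create(..., task="text-classification", column_mapping=column_mapping, ...)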
@dataclass
class Client:
"""
A client to interact with the AutoTrain API.
Attributes:
host (Optional[str]): The host URL for the AutoTrain API.
token (Optional[str]): The authentication token for the API.
username (Optional[str]): The username for the API.
Methods:
__post_init__():
Initializes the client with default values if not provided and sets up headers.
__str__():
Returns a string representation of the client with masked token.
__repr__():
Returns a string representation of the client with masked token.
create(project_name: str, task: str, base_model: str, backend: str, dataset: str, train_split: str, column_mapping: Optional[dict] = None, params: Optional[dict] = None, valid_split: Optional[str] = None):
Creates a new project on the AutoTrain platform.
get_logs(job_id: str):
Retrieves logs for a given job ID.
stop_training(job_id: str):
Stops the training for a given job ID.
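Example:
A minimal usage sketch; the project, model, and dataset names below are
placeholders, not values defined in this module:
client = Client()
resp = client.create(
project_name="my-first-project",
task="llm:sft",
base_model="org/base-model",
backend="spaces-a10g-small",
dataset="user/my-dataset",
train_split="train",
)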
"""
host: Optional[str] = None
token: Optional[str] = None
username: Optional[str] = None
def __post_init__(self):
if self.host is None:
self.host = AUTOTRAIN_API
# strip a trailing slash so f"{self.host}/api/..." does not produce a double slash
self.host = self.host.rstrip("/")
if self.token is None:
self.token = os.environ.get("HF_TOKEN")
if self.username is None:
self.username = os.environ.get("HF_USERNAME")
if self.token is None or self.username is None:
raise ValueError("Please provide a valid username and token")
self.headers = {"Authorization": f"Bearer {self.token}", "Content-Type": "application/json"}
def __str__(self):
return f"Client(host={self.host}, token=****, username={self.username})"
def __repr__(self):
return self.__str__()
def create(
self,
project_name: str,
task: str,
base_model: str,
backend: str,
dataset: str,
train_split: str,
column_mapping: Optional[dict] = None,
params: Optional[dict] = None,
valid_split: Optional[str] = None,
):
if task not in VALID_TASKS:
raise ValueError(f"Invalid task. Valid tasks are: {VALID_TASKS}")
if backend not in BACKENDS:
raise ValueError(f"Invalid backend. Valid backends are: {list(BACKENDS.keys())}")
url = f"{self.host}/api/create_project"
if task == "llm:defaut":
task = "llm:generic"
if params is None:
params = {}
if task.startswith("llm"):
params = {k: v for k, v in PARAMS["llm"].items() if k not in params}
elif task.startswith("st"):
params = {k: v for k, v in PARAMS["st"].items() if k not in params}
else:
params = {k: v for k, v in PARAMS[task].items() if k not in params}
if column_mapping is None:
column_mapping = DEFAULT_COLUMN_MAPPING[task]
# check that column_mapping provides every column the task requires
default_col_map = DEFAULT_COLUMN_MAPPING[task]
missing_cols = [k for k in default_col_map if k not in column_mapping]
if missing_cols:
raise ValueError(f"Missing columns in column_mapping: {missing_cols}")
data = {
"project_name": project_name,
"task": task,
"base_model": base_model,
"hardware": backend,
"params": params,
"username": self.username,
"column_mapping": column_mapping,
"hub_dataset": dataset,
"train_split": train_split,
"valid_split": valid_split,
}
response = requests.post(url, headers=self.headers, json=data)
if response.status_code == 200:
resp = response.json()
logger.info(
f"Project created successfully. Job ID: {resp['job_id']}. View logs at: https://hf.co/spaces/{resp['job_id']}"
)
return resp
else:
logger.error(f"Error creating project: {response.json()}")
return response.json()
def get_logs(self, job_id: str):
url = f"{self.host}/api/logs"
data = {"jid": job_id}
response = requests.post(url, headers=self.headers, json=data)
return response.json()
def stop_training(self, job_id: str):
url = f"{self.host}/api/stop_training/{job_id}"
data = {"jid": job_id}
response = requests.post(url, headers=self.headers, json=data)
return response.json()
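# Example: monitoring and stopping a running job. A minimal sketch; `client` and
# `resp` are assumed to come from a prior Client() / client.create(...) call, and
# resp["job_id"] is the identifier returned on successful project creation:
#
#   logs = client.get_logs(resp["job_id"])
#   client.stop_training(resp["job_id"])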