import os
from dataclasses import dataclass
from typing import Optional

import requests

from autotrain import logger

AUTOTRAIN_API = os.environ.get("AUTOTRAIN_API", "https://autotrain-projects-autotrain-advanced.hf.space/")

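# Backend names accepted by Client.create(), mapped to the Hugging Face Spaces
# hardware flavor each one runs on.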
BACKENDS = {
    "spaces-a10g-large": "a10g-large",
    "spaces-a10g-small": "a10g-small",
    "spaces-a100-large": "a100-large",
    "spaces-t4-medium": "t4-medium",
    "spaces-t4-small": "t4-small",
    "spaces-cpu-upgrade": "cpu-upgrade",
    "spaces-cpu-basic": "cpu-basic",
    "spaces-l4x1": "l4x1",
    "spaces-l4x4": "l4x4",
    "spaces-l40sx1": "l40sx1",
    "spaces-l40sx4": "l40sx4",
    "spaces-l40sx8": "l40sx8",
    "spaces-a10g-largex2": "a10g-largex2",
    "spaces-a10g-largex4": "a10g-largex4",
}

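# Per-task default training parameters. Client.create() fills in any of these
# keys that the caller does not supply in `params`.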
PARAMS = {}
PARAMS["llm"] = {
    "target_modules": "all-linear",
    "log": "tensorboard",
    "mixed_precision": "fp16",
    "quantization": "int4",
    "peft": True,
    "block_size": 1024,
    "epochs": 3,
    "padding": "right",
    "chat_template": "none",
    "max_completion_length": 128,
    "distributed_backend": "ddp",
    "scheduler": "linear",
    "merge_adapter": True,
}
PARAMS["text-classification"] = {
    "mixed_precision": "fp16",
    "log": "tensorboard",
}
PARAMS["st"] = {
    "mixed_precision": "fp16",
    "log": "tensorboard",
}
PARAMS["image-classification"] = {
    "mixed_precision": "fp16",
    "log": "tensorboard",
}
PARAMS["image-object-detection"] = {
    "mixed_precision": "fp16",
    "log": "tensorboard",
}
PARAMS["seq2seq"] = {
    "mixed_precision": "fp16",
    "target_modules": "all-linear",
    "log": "tensorboard",
}
PARAMS["tabular"] = {
    "categorical_imputer": "most_frequent",
    "numerical_imputer": "median",
    "numeric_scaler": "robust",
}
PARAMS["token-classification"] = {
    "mixed_precision": "fp16",
    "log": "tensorboard",
}
PARAMS["text-regression"] = {
    "mixed_precision": "fp16",
    "log": "tensorboard",
}
PARAMS["image-regression"] = {
    "mixed_precision": "fp16",
    "log": "tensorboard",
}
PARAMS["vlm"] = {
    "mixed_precision": "fp16",
    "target_modules": "all-linear",
    "log": "tensorboard",
    "quantization": "int4",
    "peft": True,
    "epochs": 3,
}
PARAMS["extractive-qa"] = {
    "mixed_precision": "fp16",
    "log": "tensorboard",
    "max_seq_length": 512,
    "max_doc_stride": 128,
}

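# Default column mappings for each task, used by Client.create() when no
# explicit `column_mapping` is provided.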
DEFAULT_COLUMN_MAPPING = {}
DEFAULT_COLUMN_MAPPING["llm:sft"] = {"text_column": "text"}
DEFAULT_COLUMN_MAPPING["llm:generic"] = {"text_column": "text"}
DEFAULT_COLUMN_MAPPING["llm:default"] = {"text_column": "text"}
DEFAULT_COLUMN_MAPPING["llm:dpo"] = {
    "prompt_column": "prompt",
    "text_column": "chosen",
    "rejected_text_column": "rejected",
}
DEFAULT_COLUMN_MAPPING["llm:orpo"] = {
    "prompt_column": "prompt",
    "text_column": "chosen",
    "rejected_text_column": "rejected",
}
DEFAULT_COLUMN_MAPPING["llm:reward"] = {"text_column": "chosen", "rejected_text_column": "rejected"}
DEFAULT_COLUMN_MAPPING["vlm:captioning"] = {"image_column": "image", "text_column": "caption"}
DEFAULT_COLUMN_MAPPING["vlm:vqa"] = {
    "image_column": "image",
    "prompt_text_column": "question",
    "text_column": "answer",
}
DEFAULT_COLUMN_MAPPING["st:pair"] = {"sentence1": "anchor", "sentence2": "positive"} | |
DEFAULT_COLUMN_MAPPING["st:pair_class"] = { | |
"sentence1_column": "premise", | |
"sentence2_column": "hypothesis", | |
"target_column": "label", | |
} | |
DEFAULT_COLUMN_MAPPING["st:pair_score"] = { | |
"sentence1_column": "sentence1", | |
"sentence2_column": "sentence2", | |
"target_column": "score", | |
} | |
DEFAULT_COLUMN_MAPPING["st:triplet"] = { | |
"sentence1_column": "anchor", | |
"sentence2_column": "positive", | |
"sentence3_column": "negative", | |
} | |
DEFAULT_COLUMN_MAPPING["st:qa"] = {"sentence1_column": "query", "sentence2_column": "answer"} | |
DEFAULT_COLUMN_MAPPING["text-classification"] = {"text_column": "text", "target_column": "target"} | |
DEFAULT_COLUMN_MAPPING["seq2seq"] = {"text_column": "text", "target_column": "target"} | |
DEFAULT_COLUMN_MAPPING["text-regression"] = {"text_column": "text", "target_column": "target"} | |
DEFAULT_COLUMN_MAPPING["token-classification"] = {"text_column": "tokens", "target_column": "tags"} | |
DEFAULT_COLUMN_MAPPING["image-classification"] = {"image_column": "image", "target_column": "label"} | |
DEFAULT_COLUMN_MAPPING["image-regression"] = {"image_column": "image", "target_column": "target"} | |
DEFAULT_COLUMN_MAPPING["image-object-detection"] = {"image_column": "image", "objects_column": "objects"} | |
DEFAULT_COLUMN_MAPPING["tabular:classification"] = {"id_column": "id", "target__columns": ["target"]} | |
DEFAULT_COLUMN_MAPPING["tabular:regression"] = {"id_column": "id", "target_columns": ["target"]} | |
DEFAULT_COLUMN_MAPPING["extractive-qa"] = { | |
"text_column": "context", | |
"question_column": "question", | |
"answer_column": "answers", | |
} | |
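# Every task the client accepts, derived from the column-mapping table above.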
VALID_TASKS = list(DEFAULT_COLUMN_MAPPING.keys())


@dataclass
class Client:
    """
    A client to interact with the AutoTrain API.

    Attributes:
        host (Optional[str]): The host URL for the AutoTrain API.
        token (Optional[str]): The authentication token for the API.
        username (Optional[str]): The username for the API.

    Methods:
        __post_init__():
            Initializes the client with default values if not provided and sets up headers.
        __str__():
            Returns a string representation of the client with a masked token.
        __repr__():
            Returns a string representation of the client with a masked token.
        create(project_name: str, task: str, base_model: str, backend: str, dataset: str, train_split: str, column_mapping: Optional[dict] = None, params: Optional[dict] = None, valid_split: Optional[str] = None):
            Creates a new project on the AutoTrain platform.
        get_logs(job_id: str):
            Retrieves logs for a given job ID.
        stop_training(job_id: str):
            Stops the training for a given job ID.
    """

    host: Optional[str] = None
    token: Optional[str] = None
    username: Optional[str] = None

    def __post_init__(self):
        if self.host is None:
            self.host = AUTOTRAIN_API
        if self.token is None:
            self.token = os.environ.get("HF_TOKEN")
        if self.username is None:
            self.username = os.environ.get("HF_USERNAME")
        if self.token is None or self.username is None:
            raise ValueError("Please provide a valid username and token")
        self.headers = {"Authorization": f"Bearer {self.token}", "Content-Type": "application/json"}

    def __str__(self):
        return f"Client(host={self.host}, token=****, username={self.username})"

    def __repr__(self):
        return self.__str__()

    def create(
        self,
        project_name: str,
        task: str,
        base_model: str,
        backend: str,
        dataset: str,
        train_split: str,
        column_mapping: Optional[dict] = None,
        params: Optional[dict] = None,
        valid_split: Optional[str] = None,
    ):
        if task not in VALID_TASKS:
            raise ValueError(f"Invalid task. Valid tasks are: {VALID_TASKS}")
        if backend not in BACKENDS:
            raise ValueError(f"Invalid backend. Valid backends are: {list(BACKENDS.keys())}")

        url = f"{self.host}/api/create_project"
        if task == "llm:default":
            task = "llm:generic"

        if params is None:
            params = {}

        # merge in the task defaults for any keys the caller did not supply
        if task.startswith("llm"):
            params.update({k: v for k, v in PARAMS["llm"].items() if k not in params})
        elif task.startswith("st"):
            params.update({k: v for k, v in PARAMS["st"].items() if k not in params})
        else:
            # PARAMS is keyed by task family, e.g. "vlm" for "vlm:captioning"
            params.update({k: v for k, v in PARAMS[task.split(":")[0]].items() if k not in params})
        if column_mapping is None:
            column_mapping = DEFAULT_COLUMN_MAPPING[task]

        # check if column_mapping is valid for the task
        default_col_map = DEFAULT_COLUMN_MAPPING[task]
        missing_cols = []
        for k, _ in default_col_map.items():
            if k not in column_mapping.keys():
                missing_cols.append(k)
        if missing_cols:
            raise ValueError(f"Missing columns in column_mapping: {missing_cols}")

        data = {
            "project_name": project_name,
            "task": task,
            "base_model": base_model,
            "hardware": backend,
            "params": params,
            "username": self.username,
            "column_mapping": column_mapping,
            "hub_dataset": dataset,
            "train_split": train_split,
            "valid_split": valid_split,
        }

        response = requests.post(url, headers=self.headers, json=data)
        if response.status_code == 200:
            resp = response.json()
            logger.info(
                f"Project created successfully. Job ID: {resp['job_id']}. View logs at: https://hf.co/spaces/{resp['job_id']}"
            )
            return resp
        else:
            logger.error(f"Error creating project: {response.json()}")
            return response.json()

    def get_logs(self, job_id: str):
        """Retrieve the logs for a given job ID."""
        url = f"{self.host}/api/logs"
        data = {"jid": job_id}
        response = requests.post(url, headers=self.headers, json=data)
        return response.json()

    def stop_training(self, job_id: str):
        """Stop the training run for a given job ID."""
        url = f"{self.host}/api/stop_training/{job_id}"
        data = {"jid": job_id}
        response = requests.post(url, headers=self.headers, json=data)
        return response.json()
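

# Example usage (a minimal sketch, kept as a comment; it assumes HF_TOKEN and
# HF_USERNAME are set in the environment, that this module is importable as
# autotrain.client, and that the model and dataset names below exist on the
# Hub -- swap in your own):
#
#   from autotrain.client import Client
#
#   client = Client()
#   resp = client.create(
#       project_name="my-text-clf",
#       task="text-classification",
#       base_model="google-bert/bert-base-uncased",
#       backend="spaces-a10g-large",
#       dataset="imdb",
#       train_split="train",
#       column_mapping={"text_column": "text", "target_column": "label"},
#   )
#   print(client.get_logs(resp["job_id"]))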