"""Webhook server that starts an AutoTrain run when a Hub webhook is received.

Exposes two routes:

* ``GET /``         -- serves ``home.html``.
* ``POST /webhook`` -- launches ``autotrain --config config.yaml`` as a child
  process and opens a discussion on the input dataset to announce it.
"""
import hmac
import os
import subprocess
from subprocess import Popen
from typing import Optional, Union

import requests
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Header, HTTPException
from fastapi.responses import FileResponse
from huggingface_hub.hf_api import HfApi

from models import WebhookPayload, config

app = FastAPI()

WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")

# These two names were referenced but never defined, so every call to
# notify_success()/AutoTrain.* raised NameError.
# NOTE(review): the defaults mirror Hugging Face's auto-retrain webhook
# example -- confirm they match the deployment, or override via environment.
AUTOTRAIN_API_URL = os.getenv(
    "AUTOTRAIN_API_URL", "https://api.autotrain.huggingface.co"
)
AUTOTRAIN_UI_URL = os.getenv(
    "AUTOTRAIN_UI_URL", "https://ui.autotrain.huggingface.co"
)

# Seconds to wait on AutoTrain HTTP calls; without a timeout ``requests``
# blocks forever on an unresponsive server.
REQUEST_TIMEOUT = 30

NOTIFICATION_TEMPLATE = """\
🌸 Hello there!

Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!

Please review and approve the project [here]({ui_url}/{project_id}/trainings) to start the training job.

(This is an automated message)
"""


@app.get("/")
async def home():
    """Serve the static landing page."""
    return FileResponse("home.html")


@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    """Handle a webhook delivery by scheduling a retraining run.

    Args:
        payload: Parsed webhook body.
        task_queue: FastAPI background-task queue (currently unused; the
            retrain is scheduled inline so failures surface as a 5xx).
        x_webhook_secret: Value of the ``X-Webhook-Secret`` request header.

    Returns:
        ``{"processed": True}`` once the retrain has been scheduled.

    Raises:
        HTTPException: 401 when a secret is configured but the header is
            missing, 403 when the header does not match.
    """
    # This endpoint spawns a training process, so it must not be callable by
    # arbitrary clients. The check is enforced whenever WEBHOOK_SECRET is set;
    # leaving the variable unset keeps the endpoint open (e.g. local testing).
    if WEBHOOK_SECRET is not None:
        if x_webhook_secret is None:
            raise HTTPException(401)
        # Constant-time comparison avoids leaking the secret through timing.
        if not hmac.compare_digest(
            x_webhook_secret.encode("utf-8"), WEBHOOK_SECRET.encode("utf-8")
        ):
            raise HTTPException(403)

    # NOTE: event filtering is intentionally disabled, so every accepted
    # delivery triggers a retrain. The filter that used to live here was:
    #
    #   if not (
    #       payload.event.action == "update"
    #       and payload.event.scope.startswith("repo.content")
    #       and payload.repo.name == config.input_dataset
    #       and payload.repo.type == "dataset"
    #   ):
    #       return {"processed": False}

    schedule_retrain(payload=payload)
    return {"processed": True}


def schedule_retrain(payload: WebhookPayload):
    """Start ``autotrain --config config.yaml`` and announce it on the Hub.

    The training process is launched without waiting for it to finish.

    Args:
        payload: Webhook body that triggered the run (not used by the CLI
            flow; kept for interface compatibility).

    Returns:
        ``{"processed": True}`` after the process has been spawned and the
        notification has been posted.

    Raises:
        OSError: If the ``autotrain`` executable cannot be started.
    """
    try:
        process = Popen(["autotrain", "--config", "config.yaml"])
    except OSError as err:
        # Popen raises OSError (e.g. FileNotFoundError), not requests.HTTPError,
        # so the previous handler could never run.
        print("ERROR while starting the autotrain CLI:")
        print(f"  {err}")
        raise

    # Notify in the community tab
    notify_success("vicuna")

    # returncode stays None until the child is polled/waited on, so report the
    # PID, which identifies the running job.
    print(f"autotrain started (pid={process.pid})")
    return {"processed": True}


class AutoTrain:
    """Minimal client for the AutoTrain HTTP API.

    Not used by ``schedule_retrain`` at the moment, which shells out to the
    ``autotrain`` CLI instead.
    """

    @staticmethod
    def create_project(payload: WebhookPayload) -> dict:
        """Create an AutoTrain project named after the triggering commit.

        Args:
            payload: Webhook body; the first 7 characters of
                ``payload.repo.headSha`` are appended to the project name.

        Returns:
            The decoded JSON response.

        Raises:
            requests.HTTPError: If the API answers with an error status.
        """
        project_resp = requests.post(
            f"{AUTOTRAIN_API_URL}/api/create_project",
            json={
                "username": config.target_namespace,
                "proj_name": f"{config.autotrain_project_prefix}-{payload.repo.headSha[:7]}",
                "task": "llm:sft",  # image-multi-class-classification
                "base_model": "meta-llama/Meta-Llama-3-8B-Instruct",
                "train_split": "train",
                "column_mapping": {
                    "text_column": "text",
                },
                "params": {
                    "block_size": 1024,
                    "model_max_length": 4096,
                    "max_prompt_length": 512,
                    "epochs": 1,
                    "batch_size": 2,
                    "lr": 0.00003,
                    "peft": True,
                    "quantization": "int4",
                    "target_modules": "all-linear",
                    "padding": "right",
                    "optimizer": "adamw_torch",
                    "scheduler": "linear",
                    "gradient_accumulation": 4,
                    "mixed_precision": "fp16",
                    "chat_template": "chatml",
                },
            },
            headers={
                "Authorization": f"Bearer {HF_ACCESS_TOKEN}",
            },
            timeout=REQUEST_TIMEOUT,
        )
        project_resp.raise_for_status()
        return project_resp.json()

    @staticmethod
    def add_data(project_id: int):
        """Attach ``config.input_dataset`` to an existing project.

        Args:
            project_id: Identifier of the AutoTrain project.

        Raises:
            requests.HTTPError: If the API answers with an error status.
        """
        # NOTE(review): this column mapping (image/label) looks like a leftover
        # from an image-classification setup and does not match the "llm:sft"
        # task used in create_project -- confirm before re-enabling this flow.
        requests.post(
            f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/dataset",
            json={
                "dataset_id": config.input_dataset,
                "dataset_split": "train",
                "split": 4,
                "col_mapping": {
                    "image": "image",
                    "label": "target",
                },
            },
            headers={
                "Authorization": f"Bearer {HF_ACCESS_TOKEN}",
            },
            timeout=REQUEST_TIMEOUT,
        ).raise_for_status()

    @staticmethod
    def start_processing(project_id: int):
        """Ask AutoTrain to start processing the project's data.

        Args:
            project_id: Identifier of the AutoTrain project.

        Returns:
            The ``requests.Response`` of the call.

        Raises:
            requests.HTTPError: If the API answers with an error status.
        """
        resp = requests.post(
            f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/start_processing",
            headers={
                "Authorization": f"Bearer {HF_ACCESS_TOKEN}",
            },
            timeout=REQUEST_TIMEOUT,
        )
        resp.raise_for_status()
        return resp


def notify_success(project_id: Union[int, str]):
    """Open a discussion on the input dataset announcing the retraining.

    Args:
        project_id: Identifier interpolated into the project link of the
            message (the CLI flow passes a plain name such as ``"vicuna"``).

    Returns:
        Whatever ``HfApi.create_discussion`` returns.
    """
    message = NOTIFICATION_TEMPLATE.format(
        input_model=config.input_model,
        input_dataset=config.input_dataset,
        project_id=project_id,
        ui_url=AUTOTRAIN_UI_URL,
    )
    return HfApi(token=HF_ACCESS_TOKEN).create_discussion(
        repo_id=config.input_dataset,
        repo_type="dataset",
        title="✨ Retraining started!",
        description=message,
        token=HF_ACCESS_TOKEN,
    )


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)