"""Webhook service that retrains a model with AutoTrain and deploys it.

A POST to ``/webhook`` schedules :func:`schedule_retrain` as a background
task. That task runs the ``autotrain`` CLI, opens a discussion on the input
dataset, and deploys an Inference Endpoint whose URL is posted to the Space.
"""
import datetime
import os
import time
from subprocess import Popen
from typing import Optional

import requests
import uvicorn
import yaml
from fastapi import BackgroundTasks, FastAPI, Header
from fastapi.responses import FileResponse
from huggingface_hub.hf_api import HfApi

from src.models import WebhookPayload, config

app = FastAPI()

WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")

# Endpoint states after which polling stops without a URL. "error" is kept
# from the original check; "failed" and "updateFailed" are the failure states
# reported by the Inference Endpoints API. A tuple (not a set) is used so the
# membership test relies on ``==`` rather than hashing.
_ENDPOINT_FAILURE_STATES = ("error", "failed", "updateFailed")


@app.get("/")
async def home():
    """Serve the static landing page ``home.html``."""
    return FileResponse("home.html")


@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    """Accept a webhook call and queue a retraining run in the background.

    Returns ``{"processed": True}`` right away. The actual work happens in
    :func:`schedule_retrain` after the response is sent.
    """
    # NOTE(review): secret verification and event filtering are disabled, so
    # ANY request to this endpoint triggers a training run and an endpoint
    # deployment. Re-enabling requires importing HTTPException from fastapi.
    # if x_webhook_secret is None:
    #     raise HTTPException(401)
    # if x_webhook_secret != WEBHOOK_SECRET:
    #     raise HTTPException(403)
    # if not (
    #     payload.event.action == "update"
    #     and payload.event.scope.startswith("repo.content")
    #     and payload.repo.name == config.input_dataset
    #     and payload.repo.type == "dataset"
    # ):
    #     # no-op
    #     return {"processed": False}
    task_queue.add_task(schedule_retrain, payload)
    return {"processed": True}


def schedule_retrain(payload: WebhookPayload):
    """Run an AutoTrain job, then notify and deploy.

    Rewrites ``project_name`` in ``src/config.yaml`` (relative to the current
    working directory), runs ``autotrain --config <that file>`` and blocks
    until it exits.

    Args:
        payload: The webhook payload. Currently unused.

    Returns:
        ``{"processed": True}``.
    """
    # The current Unix timestamp serves as the project name.
    project_id = str(int(datetime.datetime.now().timestamp()))

    yaml_path = os.path.join(os.getcwd(), "src/config.yaml")
    with open(yaml_path) as f:
        list_doc = yaml.safe_load(f)
    list_doc["project_name"] = project_id
    # NOTE(review): the shared config file is edited in place, so concurrent
    # webhook calls could race on it.
    with open(yaml_path, "w") as f:
        yaml.dump(list_doc, f, default_flow_style=False)

    result = Popen(["autotrain", "--config", yaml_path])
    result.wait()

    # Notify in the community tab.
    # NOTE(review): this message says "Retraining started!" but is sent only
    # after the training process has exited.
    notify_success(project_id)
    # NOTE(review): deploys a hard-coded model id rather than project_id, and
    # does so even if autotrain exited with a non-zero code. Confirm intent.
    deploy_model(id="1726082187")
    print(result.returncode)
    return {"processed": True}


def notify_success(project_id: str):
    """Open a discussion on the input dataset announcing the retraining.

    Args:
        project_id: Passed to the template's ``format`` call. The current
            template text does not reference it.

    Returns:
        The object returned by ``HfApi.create_discussion``.
    """
    message = NOTIFICATION_TEMPLATE.format(
        input_model=config.input_model,
        input_dataset=config.input_dataset,
        project_id=project_id,
    )
    return HfApi(token=HF_ACCESS_TOKEN).create_discussion(
        repo_id=config.input_dataset,
        repo_type="dataset",
        title="✨ Retraining started!",
        description=message,
        token=HF_ACCESS_TOKEN,
    )


def notify_url(url: str):
    """Open a discussion on the ``Platma/platma-retrain`` Space with *url*.

    Returns:
        The object returned by ``HfApi.create_discussion``.
    """
    message = URL_TEMPLATE.format(
        url=url,
    )
    return HfApi(token=HF_ACCESS_TOKEN).create_discussion(
        repo_id="Platma/platma-retrain",
        repo_type="space",
        title="✨ Endpoint is ready!",
        description=message,
        token=HF_ACCESS_TOKEN,
    )


def deploy_model(id: str):
    """Create an Inference Endpoint for ``Platma/{id}`` and wait until it runs.

    POSTs the endpoint spec to the Inference Endpoints REST API, then polls the
    endpoint every 10 seconds. When it reports ``running``, its URL is posted
    via :func:`notify_url`. Polling stops early if the endpoint reaches a
    failure state.

    Args:
        id: Suffix for the model repository (``Platma/{id}``) and the endpoint
            name (``platma-{id}``). Named ``id`` so keyword callers keep
            working.
    """
    api = HfApi(token=HF_ACCESS_TOKEN)
    url = "https://api.endpoints.huggingface.cloud/v2/endpoint/Platma"
    endpoint_name = f"platma-{id}"
    data = {
        "compute": {
            "accelerator": "gpu",
            "instanceSize": "x1",
            "instanceType": "nvidia-l4",
            "scaling": {"maxReplica": 1, "minReplica": 1, "scaleToZeroTimeout": 15},
        },
        "model": {
            "framework": "pytorch",
            "image": {
                "custom": {
                    "health_route": "/health",
                    "url": "ghcr.io/huggingface/text-generation-inference:sha-f852190",
                    "env": {
                        "MAX_BATCH_PREFILL_TOKENS": "2048",
                        "MAX_INPUT_LENGTH": "2048",
                        "MAX_TOTAL_TOKENS": "2512",
                        "MODEL_ID": "/repository",
                    },
                }
            },
            "repository": f"Platma/{id}",
            "secrets": {},
            "task": "text-generation",
        },
        "name": endpoint_name,
        "provider": {"region": "us-east-1", "vendor": "aws"},
        "type": "protected",
    }
    headers = {
        "Authorization": f"Bearer {HF_ACCESS_TOKEN}",
        "Content-Type": "application/json",
    }
    response = requests.post(url, json=data, headers=headers)
    # The creation response is only logged, not checked. If creation failed
    # (e.g. the name already exists), the endpoint with this name is still polled.
    print(response)

    endpoint = api.get_inference_endpoint(name=endpoint_name)
    while True:
        print("Fetching url")
        if endpoint.status == "running":
            print(endpoint)
            notify_url(endpoint.url)
            break
        # Without this check a failed deployment would be polled forever.
        if endpoint.status in _ENDPOINT_FAILURE_STATES:
            break
        time.sleep(10)
        endpoint = api.get_inference_endpoint(name=endpoint_name)
        print(endpoint)


NOTIFICATION_TEMPLATE = """\
🌸 Hello there!

Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!

(This is an automated message)
"""

URL_TEMPLATE = """\
Here is your endpoint: {url}

(This is an automated message)
"""


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)