Spaces:
Paused
Paused
import os | |
import requests | |
from typing import Optional | |
import uvicorn | |
from subprocess import Popen | |
import yaml | |
import datetime | |
from fastapi import FastAPI, Header, BackgroundTasks | |
from fastapi.responses import FileResponse | |
from huggingface_hub.hf_api import HfApi | |
from src.models import config, WebhookPayload | |
# FastAPI application instance; handed to uvicorn.run() in the __main__ guard.
app = FastAPI()
# Both values come from the environment; os.getenv() yields None when unset.
# WEBHOOK_SECRET is currently only referenced by the commented-out header
# check in post_webhook().
WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
# Token passed to every HfApi client and used as the Bearer token for the
# Inference Endpoints request in deploy_model().
HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
async def home():
    """Return the static ``home.html`` page as a file response."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file - confirm that it is registered on `app`.
    landing_page = FileResponse("home.html")
    return landing_page
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    """Queue a retraining run for an incoming webhook call.

    Args:
        payload: Parsed webhook body; forwarded untouched to
            ``schedule_retrain``.
        task_queue: FastAPI background-task queue the retrain job is added to.
        x_webhook_secret: Value taken from the request headers (``None`` when
            absent). Currently unused - see the note below.

    Returns:
        ``{"processed": True}`` unconditionally.
    """
    # NOTE(review): the checks that used to guard this handler are disabled
    # (they were commented out in place):
    #   * reject with 401 when the secret header is missing and with 403 when
    #     it differs from WEBHOOK_SECRET;
    #   * return {"processed": False} unless the event is an "update" whose
    #     scope starts with "repo.content" on the dataset repo named
    #     config.input_dataset.
    # As written, every request that parses as a WebhookPayload schedules a
    # retrain. Confirm this is intentional before exposing the endpoint.
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file - confirm that it is registered on `app`.
    task_queue.add_task(schedule_retrain, payload)
    response = {"processed": True}
    return response
def schedule_retrain(payload: WebhookPayload):
    """Run an AutoTrain job from ``src/config.yaml`` and deploy the result.

    Steps:
      1. Generate a project id from the current epoch timestamp (seconds).
      2. Rewrite ``project_name`` in ``src/config.yaml`` (resolved against the
         current working directory) with that id.
      3. Run ``autotrain --config <yaml>`` and block until it exits.
      4. On a zero exit code, post the notification discussion and deploy an
         endpoint for the project that was just trained.

    Args:
        payload: The webhook payload that triggered the run. It is not read
            by this function.

    Returns:
        ``{"processed": True}`` when training succeeded and the follow-up
        steps ran, ``{"processed": False}`` when ``autotrain`` exited with a
        non-zero status.

    Raises:
        OSError: If the config file cannot be read/written or the
            ``autotrain`` executable cannot be started.
    """
    # Epoch seconds double as the project name; `project_id` avoids shadowing
    # the `id` builtin.
    project_id = str(int(datetime.datetime.now().timestamp()))

    yaml_path = os.path.join(os.getcwd(), "src/config.yaml")
    with open(yaml_path) as f:
        train_config = yaml.safe_load(f)
    train_config["project_name"] = project_id
    with open(yaml_path, "w") as f:
        yaml.dump(train_config, f, default_flow_style=False)

    # Argument list (no shell) - blocks until training has finished.
    result = Popen(["autotrain", "--config", yaml_path])
    result.wait()
    print(result.returncode)

    # Do not announce or deploy a model whose training run failed.
    if result.returncode != 0:
        print(f"ERROR: autotrain exited with code {result.returncode}")
        return {"processed": False}

    # Notify in the community tab
    notify_success(project_id)
    # Deploy the project trained above rather than a fixed, previously
    # trained one. deploy_model() serves the repo "Platma/<id>"; presumably
    # the autotrain config pushes the model under the project name - confirm.
    deploy_model(id=project_id)
    return {"processed": True}
def notify_success(project_id: str):
    """Open a discussion on the input dataset repo announcing the retrain.

    Args:
        project_id: Identifier handed to ``NOTIFICATION_TEMPLATE.format``
            together with the configured input model and dataset names.

    Returns:
        The value returned by ``HfApi.create_discussion``.
    """
    body = NOTIFICATION_TEMPLATE.format(
        input_model=config.input_model,
        input_dataset=config.input_dataset,
        project_id=project_id,
    )
    hub = HfApi(token=HF_ACCESS_TOKEN)
    discussion = hub.create_discussion(
        repo_id=config.input_dataset,
        repo_type="dataset",
        title="✨ Retraining started!",
        description=body,
        token=HF_ACCESS_TOKEN,
    )
    return discussion
def notify_url(url: str):
    """Open a discussion on the ``Platma/platma-retrain`` Space with the URL.

    Args:
        url: Endpoint URL substituted into ``URL_TEMPLATE``.

    Returns:
        The value returned by ``HfApi.create_discussion``.
    """
    body = URL_TEMPLATE.format(url=url)
    hub = HfApi(token=HF_ACCESS_TOKEN)
    discussion = hub.create_discussion(
        repo_id='Platma/platma-retrain',
        repo_type="space",
        title="✨ Endpoint is ready!",
        description=body,
        token=HF_ACCESS_TOKEN,
    )
    return discussion
def deploy_model(
    id: str,
    *,
    poll_interval: float = 10.0,
    max_wait: Optional[float] = None,
):
    """Create an Inference Endpoint for ``Platma/<id>`` and wait for it.

    Sends the endpoint creation request, then polls the endpoint named
    ``platma-<id>`` until it is running (the URL is then posted through
    ``notify_url``), reports a failure status, or ``max_wait`` elapses.

    Args:
        id: Project / model identifier. Used for the model repository
            (``Platma/<id>``) and the endpoint name (``platma-<id>``).
        poll_interval: Seconds to sleep between status checks.
        max_wait: Maximum number of seconds to keep polling; ``None`` (the
            default) polls until a terminal status is seen.

    Returns:
        None.
    """
    # `time` is not imported at module level in this file.
    import time

    api = HfApi(token=HF_ACCESS_TOKEN)
    endpoint_name = f"platma-{id}"
    url = "https://api.endpoints.huggingface.cloud/v2/endpoint/Platma"
    data = {
        "compute": {
            "accelerator": "gpu",
            "instanceSize": "x1",
            "instanceType": "nvidia-l4",
            "scaling": {"maxReplica": 1, "minReplica": 1, "scaleToZeroTimeout": 15},
        },
        "model": {
            "framework": "pytorch",
            "image": {
                "custom": {
                    "health_route": "/health",
                    "url": "ghcr.io/huggingface/text-generation-inference:sha-f852190",
                    "env": {
                        "MAX_BATCH_PREFILL_TOKENS": "2048",
                        "MAX_INPUT_LENGTH": "2048",
                        "MAX_TOTAL_TOKENS": "2512",
                        "MODEL_ID": "/repository",
                    },
                },
            },
            "repository": f"Platma/{id}",
            "secrets": {},
            "task": "text-generation",
        },
        "name": endpoint_name,
        "provider": {"region": "us-east-1", "vendor": "aws"},
        "type": "protected",
    }
    headers = {"Authorization": f"Bearer {HF_ACCESS_TOKEN}", "Content-Type": "application/json"}

    # Bound the request so a stalled connection cannot hang the worker.
    r = requests.post(url, json=data, headers=headers, timeout=60)
    print(r)
    if not r.ok:
        # Keep going as before: the lookup below raises if the endpoint truly
        # does not exist, and still works if it was created earlier.
        print(f"ERROR: endpoint creation request failed: {r.status_code} {r.text}")

    # "error" is kept from the previous implementation; "failed" and
    # "updateFailed" are the failure states reported by Inference Endpoints.
    failure_statuses = ("failed", "updateFailed", "error")
    deadline = None if max_wait is None else time.monotonic() + max_wait

    endpoint = api.get_inference_endpoint(name=endpoint_name)
    while True:
        print("Fetching url")
        if endpoint.status == "running":
            print(endpoint)
            notify_url(endpoint.url)
            break
        if endpoint.status in failure_statuses:
            break
        if deadline is not None and time.monotonic() >= deadline:
            print(f"ERROR: endpoint {endpoint_name} not running after {max_wait} seconds")
            break
        time.sleep(poll_interval)
        endpoint = api.get_inference_endpoint(name=endpoint_name)
    print(endpoint)
# Discussion body used by notify_success(). Filled with str.format(); the
# placeholders referenced here are {input_dataset} and {input_model}.
# notify_success() also passes project_id, which this text does not use.
NOTIFICATION_TEMPLATE = """\
🌸 Hello there!
Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!
(This is an automated message)
"""
# Discussion body used by notify_url(). Filled with str.format(); the only
# placeholder is {url}.
URL_TEMPLATE = """\
Here is your endpoint: {url}
(This is an automated message)
"""
# When executed as a script, serve `app` with uvicorn on all interfaces
# (0.0.0.0), port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)