Spaces:
Paused
Paused
add local autotrain
Browse files- src/main.py +60 -130
- src/models.py +16 -13
src/main.py
CHANGED
@@ -2,10 +2,9 @@ import os
|
|
2 |
import requests
|
3 |
from typing import Optional
|
4 |
import uvicorn
|
5 |
-
import subprocess
|
6 |
from subprocess import Popen
|
7 |
|
8 |
-
from fastapi import FastAPI, Header,
|
9 |
from fastapi.responses import FileResponse
|
10 |
from huggingface_hub.hf_api import HfApi
|
11 |
|
@@ -13,151 +12,82 @@ from models import config, WebhookPayload
|
|
13 |
|
14 |
app = FastAPI()
|
15 |
|
16 |
-
|
17 |
WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
|
18 |
HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
|
|
|
|
|
19 |
@app.get("/")
|
20 |
async def home():
|
21 |
-
|
|
|
22 |
|
23 |
@app.post("/webhook")
|
24 |
async def post_webhook(
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
|
49 |
|
50 |
def schedule_retrain(payload: WebhookPayload):
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
"params": {
|
83 |
-
"block_size": 1024,
|
84 |
-
"model_max_length": 4096,
|
85 |
-
"max_prompt_length": 512,
|
86 |
-
"epochs": 1,
|
87 |
-
"batch_size": 2,
|
88 |
-
"lr": 0.00003,
|
89 |
-
"peft": True,
|
90 |
-
"quantization": "int4",
|
91 |
-
"target_modules": "all-linear",
|
92 |
-
"padding": "right",
|
93 |
-
"optimizer": "adamw_torch",
|
94 |
-
"scheduler": "linear",
|
95 |
-
"gradient_accumulation": 4,
|
96 |
-
"mixed_precision": "fp16",
|
97 |
-
"chat_template": "chatml"
|
98 |
-
}
|
99 |
-
},
|
100 |
-
headers={
|
101 |
-
"Authorization": f"Bearer {HF_ACCESS_TOKEN}"
|
102 |
-
}
|
103 |
-
)
|
104 |
-
project_resp.raise_for_status()
|
105 |
-
return project_resp.json()
|
106 |
-
|
107 |
-
@staticmethod
|
108 |
-
def add_data(project_id:int):
|
109 |
-
requests.post(
|
110 |
-
f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/dataset",
|
111 |
-
json={
|
112 |
-
"dataset_id": config.input_dataset,
|
113 |
-
"dataset_split": "train",
|
114 |
-
"split": 4,
|
115 |
-
"col_mapping": {
|
116 |
-
"image": "image",
|
117 |
-
"label": "target",
|
118 |
-
}
|
119 |
-
},
|
120 |
-
headers={
|
121 |
-
"Authorization": f"Bearer {HF_ACCESS_TOKEN}",
|
122 |
-
}
|
123 |
-
).raise_for_status()
|
124 |
-
|
125 |
-
@staticmethod
|
126 |
-
def start_processing(project_id: int):
|
127 |
-
resp = requests.post(
|
128 |
-
f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/start_processing",
|
129 |
-
headers={
|
130 |
-
"Authorization": f"Bearer {HF_ACCESS_TOKEN}",
|
131 |
-
}
|
132 |
-
)
|
133 |
-
resp.raise_for_status()
|
134 |
-
return resp
|
135 |
-
|
136 |
-
|
137 |
-
def notify_success(project_id: int):
|
138 |
-
message = NOTIFICATION_TEMPLATE.format(
|
139 |
-
input_model=config.input_model,
|
140 |
-
input_dataset=config.input_dataset,
|
141 |
-
project_id=project_id,
|
142 |
-
ui_url=AUTOTRAIN_UI_URL,
|
143 |
-
)
|
144 |
-
return HfApi(token=HF_ACCESS_TOKEN).create_discussion(
|
145 |
-
repo_id=config.input_dataset,
|
146 |
-
repo_type="dataset",
|
147 |
-
title="✨ Retraining started!",
|
148 |
-
description=message,
|
149 |
-
token=HF_ACCESS_TOKEN,
|
150 |
-
)
|
151 |
|
152 |
NOTIFICATION_TEMPLATE = """\
|
153 |
🌸 Hello there!
|
154 |
|
155 |
Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!
|
156 |
|
157 |
-
Please review and approve the project [here]({ui_url}/{project_id}/trainings) to start the training job.
|
158 |
-
|
159 |
(This is an automated message)
|
160 |
"""
|
161 |
|
162 |
if __name__ == "__main__":
|
163 |
-
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
2 |
import requests
|
3 |
from typing import Optional
|
4 |
import uvicorn
|
|
|
5 |
from subprocess import Popen
|
6 |
|
7 |
+
from fastapi import FastAPI, Header, BackgroundTasks
|
8 |
from fastapi.responses import FileResponse
|
9 |
from huggingface_hub.hf_api import HfApi
|
10 |
|
|
|
12 |
|
13 |
app = FastAPI()
|
14 |
|
|
|
15 |
WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
|
16 |
HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
|
17 |
+
|
18 |
+
|
19 |
@app.get("/")
async def home():
    """Serve the static ``home.html`` page for requests to the root path."""
    landing_page = FileResponse("home.html")
    return landing_page
|
22 |
+
|
23 |
|
24 |
@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    """Validate an incoming webhook call and queue a retraining run.

    Args:
        payload: Parsed webhook body (event and repo description).
        task_queue: FastAPI background-task queue; the retraining is
            scheduled on it so the response is sent without waiting.
        x_webhook_secret: Value of the ``X-Webhook-Secret`` request header,
            or ``None`` when the header is absent.

    Returns:
        ``{"processed": True}`` when a retraining was queued,
        ``{"processed": False}`` when the event is not a content update of
        the configured input dataset.

    Raises:
        HTTPException: 401 when a secret is configured but the header is
            missing, 403 when the header does not match the secret.
    """
    # Local imports keep this block self-contained; both names are only
    # needed by the authentication check below.
    import hmac

    from fastapi import HTTPException

    # NOTE(review): when WEBHOOK_SECRET is unset the endpoint stays
    # unauthenticated (as it was before this check was restored) so local
    # runs keep working. Set the variable in any deployed environment.
    if WEBHOOK_SECRET:
        if x_webhook_secret is None:
            raise HTTPException(401)
        # Compare as bytes: compare_digest rejects non-ASCII str values, and
        # the header content is attacker-controlled.
        secrets_match = hmac.compare_digest(
            x_webhook_secret.encode("utf-8"),
            WEBHOOK_SECRET.encode("utf-8"),
        )
        if not secrets_match:
            raise HTTPException(403)

    # Only react to content updates of the configured dataset. Without this
    # filter every event retrains, including (if the webhook is subscribed to
    # them) the discussion that notify_success opens on the same dataset,
    # which would re-trigger this endpoint in a loop.
    is_dataset_update = (
        payload.event.action == "update"
        and payload.event.scope.startswith("repo.content")
        and payload.repo.name == config.input_dataset
        and payload.repo.type == "dataset"
    )
    if not is_dataset_update:
        # no-op
        return {"processed": False}

    # schedule_retrain is synchronous and performs network I/O; running it
    # as a background task keeps it off the event loop.
    task_queue.add_task(schedule_retrain, payload)

    return {"processed": True}
|
49 |
|
50 |
|
51 |
def schedule_retrain(payload: WebhookPayload):
    """Launch a local AutoTrain run and announce it on the dataset repo.

    Starts ``autotrain --config config.yaml`` as a child process without
    waiting for it to finish, then posts the notification discussion.

    Args:
        payload: The webhook payload that triggered the run. It is not
            inspected here; the training setup comes from ``config.yaml``.

    Returns:
        ``{"processed": True}`` once the process has been started and the
        notification has been posted.

    Raises:
        OSError: If the ``autotrain`` executable cannot be started (for
            example ``FileNotFoundError`` when it is not on ``PATH``).
    """
    command = ['autotrain', '--config', 'config.yaml']
    try:
        process = Popen(command)
    except OSError as err:
        # Popen reports launch failures as OSError; the HTTP error handler
        # that used to sit here could never fire for a subprocess launch.
        print("ERROR while launching AutoTrain:")
        print(f"  {err}")
        raise
    # Notify in the community tab
    notify_success('vicuna')
    # returncode is only populated after poll()/wait(); poll() does not block
    # and yields None while the training process is still running.
    print(process.poll())
    return {"processed": True}
|
67 |
+
|
68 |
+
|
69 |
+
def notify_success(project_id: str):
    """Open a discussion on the input dataset announcing the retraining.

    Args:
        project_id: Identifier handed to the notification template as its
            ``project_id`` field.

    Returns:
        Whatever ``HfApi.create_discussion`` returns for the new discussion.
    """
    # Render the announcement text from the configured model/dataset names.
    body = NOTIFICATION_TEMPLATE.format(
        input_model=config.input_model,
        input_dataset=config.input_dataset,
        project_id=project_id,
    )
    hub_client = HfApi(token=HF_ACCESS_TOKEN)
    discussion = hub_client.create_discussion(
        repo_id=config.input_dataset,
        repo_type="dataset",
        title="✨ Retraining started!",
        description=body,
        token=HF_ACCESS_TOKEN,
    )
    return discussion
|
82 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
NOTIFICATION_TEMPLATE = """\
|
85 |
🌸 Hello there!
|
86 |
|
87 |
Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!
|
88 |
|
|
|
|
|
89 |
(This is an automated message)
|
90 |
"""
|
91 |
|
92 |
if __name__ == "__main__":
|
93 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
src/models.py
CHANGED
@@ -2,27 +2,30 @@ import os
|
|
2 |
from pydantic import BaseModel
|
3 |
from typing import Literal
|
4 |
|
|
|
5 |
class Config(BaseModel):
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
|
11 |
|
12 |
class WebhookPayloadEvent(BaseModel):
|
13 |
-
|
14 |
-
|
|
|
15 |
|
16 |
class WebhookPayloadRepo(BaseModel):
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
22 |
|
23 |
class WebhookPayload(BaseModel):
|
24 |
-
|
25 |
-
|
26 |
|
27 |
|
28 |
config = Config.parse_file(os.path.join(os.getcwd(), "config.json"))
|
|
|
2 |
from pydantic import BaseModel
|
3 |
from typing import Literal
|
4 |
|
5 |
+
|
6 |
class Config(BaseModel):
    """Application settings; the module-level ``config`` instance is parsed
    from ``config.json`` in the current working directory."""

    # NOTE(review): presumably the Hub namespace that receives trained
    # models - not referenced in the visible code, confirm.
    target_namespace: str
    # Dataset repo id; webhook events are matched against it and the
    # notification discussion is opened on it.
    input_dataset: str
    # Model repo id, linked from the notification message.
    input_model: str
    # NOTE(review): presumably a name prefix for AutoTrain projects - not
    # referenced in the visible code, confirm.
    autotrain_project_prefix: str
|
11 |
|
12 |
|
13 |
class WebhookPayloadEvent(BaseModel):
    """The ``event`` section of a webhook payload."""

    # What happened; validation rejects any value outside these three.
    action: Literal["create", "update", "delete"]
    # Free-form scope string, e.g. one starting with "repo.content".
    scope: str
|
16 |
+
|
17 |
|
18 |
class WebhookPayloadRepo(BaseModel):
    """The ``repo`` section of a webhook payload.

    Field names mirror the incoming JSON keys, which is why ``type`` and
    ``id`` shadow builtins and ``headSha`` is camelCase.
    """

    # Kind of repository the event refers to.
    type: Literal["dataset", "model", "space"]
    # Repository name; compared against the configured input dataset.
    name: str
    id: str
    private: bool
    # NOTE(review): presumably the commit SHA at the head of the repo - confirm.
    headSha: str
|
24 |
+
|
25 |
|
26 |
class WebhookPayload(BaseModel):
    """Body of a webhook request: the event plus the repository it concerns."""

    # What happened and in which scope.
    event: WebhookPayloadEvent
    # The repository the event refers to.
    repo: WebhookPayloadRepo
|
29 |
|
30 |
|
31 |
config = Config.parse_file(os.path.join(os.getcwd(), "config.json"))
|