Spaces:
Paused
Paused
add local autotrain
Browse files- requirements.txt +1 -0
- src/config.json +6 -0
- src/config.yaml +39 -0
- src/main.py +56 -35
requirements.txt
CHANGED
@@ -2,3 +2,4 @@ fastapi==0.74.*
|
|
2 |
requests==2.27.*
|
3 |
huggingface_hub==0.11.*
|
4 |
uvicorn[standard]==0.17.*
|
|
|
|
2 |
requests==2.27.*
|
3 |
huggingface_hub==0.11.*
|
4 |
uvicorn[standard]==0.17.*
|
5 |
+
autotrain-advanced==0.8.12
|
src/config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"target_namespace": "baryshych",
|
3 |
+
"input_dataset": "huggingface-projects/auto-retrain-input-dataset",
|
4 |
+
"input_model": "microsoft/resnet-50",
|
5 |
+
"autotrain_project_prefix": "platma-retrain"
|
6 |
+
}
|
src/config.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
task: llm-sft
|
2 |
+
base_model: meta-llama/Meta-Llama-3.1-8B-Instruct
|
3 |
+
project_name: llama
|
4 |
+
log: tensorboard
|
5 |
+
backend: local
|
6 |
+
|
7 |
+
data:
|
8 |
+
path: baryshych/platma
|
9 |
+
train_split: train
|
10 |
+
valid_split: null
|
11 |
+
chat_template: null
|
12 |
+
column_mapping:
|
13 |
+
text_column: text
|
14 |
+
|
15 |
+
params:
|
16 |
+
block_size: 1024
|
17 |
+
lr: 1e-4
|
18 |
+
warmup_ratio: 0.1
|
19 |
+
weight_decay: 0.01
|
20 |
+
epochs: 1
|
21 |
+
batch_size: 2
|
22 |
+
gradient_accumulation: 8
|
23 |
+
mixed_precision: fp16
|
24 |
+
peft: True
|
25 |
+
quantization: null
|
26 |
+
lora_r: 16
|
27 |
+
lora_alpha: 32
|
28 |
+
lora_dropout: 0.05
|
29 |
+
unsloth: False
|
30 |
+
optimizer: paged_adamw_8bit
|
31 |
+
target_modules: all-linear
|
32 |
+
padding: right
|
33 |
+
# removed duplicate 'optimizer' key — it is already set to paged_adamw_8bit above; YAML mapping keys must be unique
|
34 |
+
scheduler: cosine
|
35 |
+
|
36 |
+
hub:
|
37 |
+
username: baryshych
|
38 |
+
token: ${HF_ACCESS_TOKEN}
|
39 |
+
push_to_hub: True
|
src/main.py
CHANGED
@@ -1,21 +1,21 @@
|
|
1 |
import os
|
2 |
import requests
|
3 |
from typing import Optional
|
|
|
|
|
|
|
4 |
|
5 |
from fastapi import FastAPI, Header, HTTPException, BackgroundTasks
|
6 |
from fastapi.responses import FileResponse
|
7 |
from huggingface_hub.hf_api import HfApi
|
8 |
|
9 |
-
from
|
10 |
-
|
11 |
-
WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
|
12 |
-
HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
|
13 |
-
AUTOTRAIN_API_URL = "https://api.autotrain.huggingface.co"
|
14 |
-
AUTOTRAIN_UI_URL = "https://ui.autotrain.huggingface.co"
|
15 |
-
|
16 |
|
17 |
app = FastAPI()
|
18 |
|
|
|
|
|
|
|
19 |
@app.get("/")
|
20 |
async def home():
|
21 |
return FileResponse("home.html")
|
@@ -26,23 +26,23 @@ async def post_webhook(
|
|
26 |
task_queue: BackgroundTasks,
|
27 |
x_webhook_secret: Optional[str] = Header(default=None),
|
28 |
):
|
29 |
-
if x_webhook_secret is None:
|
30 |
-
|
31 |
-
if x_webhook_secret != WEBHOOK_SECRET:
|
32 |
-
|
33 |
-
if not (
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
):
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
task_queue.add_task(
|
43 |
-
|
44 |
-
|
45 |
-
)
|
46 |
|
47 |
return {"processed": True}
|
48 |
|
@@ -50,17 +50,18 @@ async def post_webhook(
|
|
50 |
def schedule_retrain(payload: WebhookPayload):
|
51 |
# Create the autotrain project
|
52 |
try:
|
53 |
-
|
54 |
-
AutoTrain.
|
55 |
-
AutoTrain.
|
|
|
56 |
except requests.HTTPError as err:
|
57 |
print("ERROR while requesting AutoTrain API:")
|
58 |
print(f" code: {err.response.status_code}")
|
59 |
print(f" {err.response.json()}")
|
60 |
raise
|
61 |
# Notify in the community tab
|
62 |
-
notify_success(
|
63 |
-
|
64 |
return {"processed": True}
|
65 |
|
66 |
|
@@ -68,15 +69,32 @@ class AutoTrain:
|
|
68 |
@staticmethod
|
69 |
def create_project(payload: WebhookPayload) -> dict:
|
70 |
project_resp = requests.post(
|
71 |
-
f"{AUTOTRAIN_API_URL}/
|
72 |
json={
|
73 |
"username": config.target_namespace,
|
74 |
"proj_name": f"{config.autotrain_project_prefix}-{payload.repo.headSha[:7]}",
|
75 |
-
"task":
|
76 |
-
"
|
77 |
-
|
78 |
-
|
79 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
}
|
81 |
},
|
82 |
headers={
|
@@ -140,3 +158,6 @@ Please review and approve the project [here]({ui_url}/{project_id}/trainings) to
|
|
140 |
|
141 |
(This is an automated message)
|
142 |
"""
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import requests
|
3 |
from typing import Optional
|
4 |
+
import uvicorn
|
5 |
+
import subprocess
|
6 |
+
from subprocess import Popen
|
7 |
|
8 |
from fastapi import FastAPI, Header, HTTPException, BackgroundTasks
|
9 |
from fastapi.responses import FileResponse
|
10 |
from huggingface_hub.hf_api import HfApi
|
11 |
|
12 |
+
from models import config, WebhookPayload
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
app = FastAPI()
|
15 |
|
16 |
+
|
17 |
+
WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
|
18 |
+
HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
|
19 |
@app.get("/")
|
20 |
async def home():
|
21 |
return FileResponse("home.html")
|
|
|
26 |
task_queue: BackgroundTasks,
|
27 |
x_webhook_secret: Optional[str] = Header(default=None),
|
28 |
):
|
29 |
+
# if x_webhook_secret is None:
|
30 |
+
# raise HTTPException(401)
|
31 |
+
# if x_webhook_secret != WEBHOOK_SECRET:
|
32 |
+
# raise HTTPException(403)
|
33 |
+
# if not (
|
34 |
+
# payload.event.action == "update"
|
35 |
+
# and payload.event.scope.startswith("repo.content")
|
36 |
+
# and payload.repo.name == config.input_dataset
|
37 |
+
# and payload.repo.type == "dataset"
|
38 |
+
# ):
|
39 |
+
# # no-op
|
40 |
+
# return {"processed": False}
|
41 |
+
schedule_retrain(payload=payload)
|
42 |
+
# task_queue.add_task(
|
43 |
+
# schedule_retrain,
|
44 |
+
# payload
|
45 |
+
# )
|
46 |
|
47 |
return {"processed": True}
|
48 |
|
|
|
50 |
def schedule_retrain(payload: WebhookPayload):
|
51 |
# Create the autotrain project
|
52 |
try:
|
53 |
+
result = Popen(['autotrain', '--config', 'config.yaml'])
|
54 |
+
# project = AutoTrain.create_project(payload)
|
55 |
+
# AutoTrain.add_data(project_id=project["id"])
|
56 |
+
# AutoTrain.start_processing(project_id=project["id"])
|
57 |
except requests.HTTPError as err:
|
58 |
print("ERROR while requesting AutoTrain API:")
|
59 |
print(f" code: {err.response.status_code}")
|
60 |
print(f" {err.response.json()}")
|
61 |
raise
|
62 |
# Notify in the community tab
|
63 |
+
notify_success('vicuna')
|
64 |
+
print(result.returncode)
|
65 |
return {"processed": True}
|
66 |
|
67 |
|
|
|
69 |
@staticmethod
|
70 |
def create_project(payload: WebhookPayload) -> dict:
|
71 |
project_resp = requests.post(
|
72 |
+
f"{AUTOTRAIN_API_URL}/api/create_project",
|
73 |
json={
|
74 |
"username": config.target_namespace,
|
75 |
"proj_name": f"{config.autotrain_project_prefix}-{payload.repo.headSha[:7]}",
|
76 |
+
"task": 'llm:sft', # image-multi-class-classification
|
77 |
+
"base_model": "meta-llama/Meta-Llama-3-8B-Instruct",
|
78 |
+
"train_split": "train",
|
79 |
+
"column_mapping": {
|
80 |
+
"text_column": "text",
|
81 |
+
},
|
82 |
+
"params": {
|
83 |
+
"block_size": 1024,
|
84 |
+
"model_max_length": 4096,
|
85 |
+
"max_prompt_length": 512,
|
86 |
+
"epochs": 1,
|
87 |
+
"batch_size": 2,
|
88 |
+
"lr": 0.00003,
|
89 |
+
"peft": True,
|
90 |
+
"quantization": "int4",
|
91 |
+
"target_modules": "all-linear",
|
92 |
+
"padding": "right",
|
93 |
+
"optimizer": "adamw_torch",
|
94 |
+
"scheduler": "linear",
|
95 |
+
"gradient_accumulation": 4,
|
96 |
+
"mixed_precision": "fp16",
|
97 |
+
"chat_template": "chatml"
|
98 |
}
|
99 |
},
|
100 |
headers={
|
|
|
158 |
|
159 |
(This is an automated message)
|
160 |
"""
|
161 |
+
|
162 |
+
if __name__ == "__main__":
|
163 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|