baryshych committed on
Commit
a7c7f70
·
1 Parent(s): e882322

add local autotrain

Browse files
Files changed (2) hide show
  1. src/main.py +60 -130
  2. src/models.py +16 -13
src/main.py CHANGED
@@ -2,10 +2,9 @@ import os
2
  import requests
3
  from typing import Optional
4
  import uvicorn
5
- import subprocess
6
  from subprocess import Popen
7
 
8
- from fastapi import FastAPI, Header, HTTPException, BackgroundTasks
9
  from fastapi.responses import FileResponse
10
  from huggingface_hub.hf_api import HfApi
11
 
@@ -13,151 +12,82 @@ from models import config, WebhookPayload
13
 
14
  app = FastAPI()
15
 
16
-
17
  WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
18
  HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
 
 
19
  @app.get("/")
20
  async def home():
21
- return FileResponse("home.html")
 
22
 
23
  @app.post("/webhook")
24
  async def post_webhook(
25
- payload: WebhookPayload,
26
- task_queue: BackgroundTasks,
27
- x_webhook_secret: Optional[str] = Header(default=None),
28
- ):
29
- # if x_webhook_secret is None:
30
- # raise HTTPException(401)
31
- # if x_webhook_secret != WEBHOOK_SECRET:
32
- # raise HTTPException(403)
33
- # if not (
34
- # payload.event.action == "update"
35
- # and payload.event.scope.startswith("repo.content")
36
- # and payload.repo.name == config.input_dataset
37
- # and payload.repo.type == "dataset"
38
- # ):
39
- # # no-op
40
- # return {"processed": False}
41
- schedule_retrain(payload=payload)
42
- # task_queue.add_task(
43
- # schedule_retrain,
44
- # payload
45
- # )
46
-
47
- return {"processed": True}
48
 
49
 
50
  def schedule_retrain(payload: WebhookPayload):
51
- # Create the autotrain project
52
- try:
53
- result = Popen(['autotrain', '--config', 'config.yaml'])
54
- # project = AutoTrain.create_project(payload)
55
- # AutoTrain.add_data(project_id=project["id"])
56
- # AutoTrain.start_processing(project_id=project["id"])
57
- except requests.HTTPError as err:
58
- print("ERROR while requesting AutoTrain API:")
59
- print(f" code: {err.response.status_code}")
60
- print(f" {err.response.json()}")
61
- raise
62
- # Notify in the community tab
63
- notify_success('vicuna')
64
- print(result.returncode)
65
- return {"processed": True}
66
-
67
-
68
- class AutoTrain:
69
- @staticmethod
70
- def create_project(payload: WebhookPayload) -> dict:
71
- project_resp = requests.post(
72
- f"{AUTOTRAIN_API_URL}/api/create_project",
73
- json={
74
- "username": config.target_namespace,
75
- "proj_name": f"{config.autotrain_project_prefix}-{payload.repo.headSha[:7]}",
76
- "task": 'llm:sft', # image-multi-class-classification
77
- "base_model": "meta-llama/Meta-Llama-3-8B-Instruct",
78
- "train_split": "train",
79
- "column_mapping": {
80
- "text_column": "text",
81
- },
82
- "params": {
83
- "block_size": 1024,
84
- "model_max_length": 4096,
85
- "max_prompt_length": 512,
86
- "epochs": 1,
87
- "batch_size": 2,
88
- "lr": 0.00003,
89
- "peft": True,
90
- "quantization": "int4",
91
- "target_modules": "all-linear",
92
- "padding": "right",
93
- "optimizer": "adamw_torch",
94
- "scheduler": "linear",
95
- "gradient_accumulation": 4,
96
- "mixed_precision": "fp16",
97
- "chat_template": "chatml"
98
- }
99
- },
100
- headers={
101
- "Authorization": f"Bearer {HF_ACCESS_TOKEN}"
102
- }
103
- )
104
- project_resp.raise_for_status()
105
- return project_resp.json()
106
-
107
- @staticmethod
108
- def add_data(project_id:int):
109
- requests.post(
110
- f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/dataset",
111
- json={
112
- "dataset_id": config.input_dataset,
113
- "dataset_split": "train",
114
- "split": 4,
115
- "col_mapping": {
116
- "image": "image",
117
- "label": "target",
118
- }
119
- },
120
- headers={
121
- "Authorization": f"Bearer {HF_ACCESS_TOKEN}",
122
- }
123
- ).raise_for_status()
124
-
125
- @staticmethod
126
- def start_processing(project_id: int):
127
- resp = requests.post(
128
- f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/start_processing",
129
- headers={
130
- "Authorization": f"Bearer {HF_ACCESS_TOKEN}",
131
- }
132
- )
133
- resp.raise_for_status()
134
- return resp
135
-
136
-
137
- def notify_success(project_id: int):
138
- message = NOTIFICATION_TEMPLATE.format(
139
- input_model=config.input_model,
140
- input_dataset=config.input_dataset,
141
- project_id=project_id,
142
- ui_url=AUTOTRAIN_UI_URL,
143
- )
144
- return HfApi(token=HF_ACCESS_TOKEN).create_discussion(
145
- repo_id=config.input_dataset,
146
- repo_type="dataset",
147
- title="✨ Retraining started!",
148
- description=message,
149
- token=HF_ACCESS_TOKEN,
150
- )
151
 
152
  NOTIFICATION_TEMPLATE = """\
153
  🌸 Hello there!
154
 
155
  Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!
156
 
157
- Please review and approve the project [here]({ui_url}/{project_id}/trainings) to start the training job.
158
-
159
  (This is an automated message)
160
  """
161
 
162
  if __name__ == "__main__":
163
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
2
  import requests
3
  from typing import Optional
4
  import uvicorn
 
5
  from subprocess import Popen
6
 
7
+ from fastapi import FastAPI, Header, BackgroundTasks
8
  from fastapi.responses import FileResponse
9
  from huggingface_hub.hf_api import HfApi
10
 
 
12
 
13
  app = FastAPI()
14
 
 
15
  WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
16
  HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
17
+
18
+
19
  @app.get("/")
20
  async def home():
21
+ return FileResponse("home.html")
22
+
23
 
24
  @app.post("/webhook")
25
  async def post_webhook(
26
+ payload: WebhookPayload,
27
+ task_queue: BackgroundTasks,
28
+ x_webhook_secret: Optional[str] = Header(default=None),
29
+ ):
30
+ # if x_webhook_secret is None:
31
+ # raise HTTPException(401)
32
+ # if x_webhook_secret != WEBHOOK_SECRET:
33
+ # raise HTTPException(403)
34
+ # if not (
35
+ # payload.event.action == "update"
36
+ # and payload.event.scope.startswith("repo.content")
37
+ # and payload.repo.name == config.input_dataset
38
+ # and payload.repo.type == "dataset"
39
+ # ):
40
+ # # no-op
41
+ # return {"processed": False}
42
+ schedule_retrain(payload=payload)
43
+ # task_queue.add_task(
44
+ # schedule_retrain,
45
+ # payload
46
+ # )
47
+
48
+ return {"processed": True}
49
 
50
 
51
  def schedule_retrain(payload: WebhookPayload):
52
+ # Create the autotrain project
53
+ try:
54
+ result = Popen(['autotrain', '--config', 'config.yaml'])
55
+ # project = AutoTrain.create_project(payload)
56
+ # AutoTrain.add_data(project_id=project["id"])
57
+ # AutoTrain.start_processing(project_id=project["id"])
58
+ except requests.HTTPError as err:
59
+ print("ERROR while requesting AutoTrain API:")
60
+ print(f" code: {err.response.status_code}")
61
+ print(f" {err.response.json()}")
62
+ raise
63
+ # Notify in the community tab
64
+ notify_success('vicuna')
65
+ print(result.returncode)
66
+ return {"processed": True}
67
+
68
+
69
+ def notify_success(project_id: str):
70
+ message = NOTIFICATION_TEMPLATE.format(
71
+ input_model=config.input_model,
72
+ input_dataset=config.input_dataset,
73
+ project_id=project_id,
74
+ )
75
+ return HfApi(token=HF_ACCESS_TOKEN).create_discussion(
76
+ repo_id=config.input_dataset,
77
+ repo_type="dataset",
78
+ title=" Retraining started!",
79
+ description=message,
80
+ token=HF_ACCESS_TOKEN,
81
+ )
82
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  NOTIFICATION_TEMPLATE = """\
85
  🌸 Hello there!
86
 
87
  Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!
88
 
 
 
89
  (This is an automated message)
90
  """
91
 
92
  if __name__ == "__main__":
93
+ uvicorn.run(app, host="0.0.0.0", port=8000)
src/models.py CHANGED
@@ -2,27 +2,30 @@ import os
2
  from pydantic import BaseModel
3
  from typing import Literal
4
 
 
5
  class Config(BaseModel):
6
- target_namespace: str
7
- input_dataset: str
8
- input_model: str
9
- autotrain_project_prefix: str
10
 
11
 
12
  class WebhookPayloadEvent(BaseModel):
13
- action: Literal["create", "update", "delete"]
14
- scope: str
 
15
 
16
  class WebhookPayloadRepo(BaseModel):
17
- type: Literal["dataset", "model", "space"]
18
- name: str
19
- id: str
20
- private: bool
21
- headSha: str
 
22
 
23
  class WebhookPayload(BaseModel):
24
- event: WebhookPayloadEvent
25
- repo: WebhookPayloadRepo
26
 
27
 
28
  config = Config.parse_file(os.path.join(os.getcwd(), "config.json"))
 
2
  from pydantic import BaseModel
3
  from typing import Literal
4
 
5
+
6
  class Config(BaseModel):
7
+ target_namespace: str
8
+ input_dataset: str
9
+ input_model: str
10
+ autotrain_project_prefix: str
11
 
12
 
13
  class WebhookPayloadEvent(BaseModel):
14
+ action: Literal["create", "update", "delete"]
15
+ scope: str
16
+
17
 
18
  class WebhookPayloadRepo(BaseModel):
19
+ type: Literal["dataset", "model", "space"]
20
+ name: str
21
+ id: str
22
+ private: bool
23
+ headSha: str
24
+
25
 
26
  class WebhookPayload(BaseModel):
27
+ event: WebhookPayloadEvent
28
+ repo: WebhookPayloadRepo
29
 
30
 
31
  config = Config.parse_file(os.path.join(os.getcwd(), "config.json"))