baryshych committed
Commit a075ab3 · 1 Parent(s): 981ad5c

add local autotrain

Files changed (4)
  1. requirements.txt +1 -0
  2. src/config.json +6 -0
  3. src/config.yaml +39 -0
  4. src/main.py +56 -35
requirements.txt CHANGED
@@ -2,3 +2,4 @@ fastapi==0.74.*
  requests==2.27.*
  huggingface_hub==0.11.*
  uvicorn[standard]==0.17.*
+ autotrain-advanced==0.8.12
src/config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "target_namespace": "baryshych",
+   "input_dataset": "huggingface-projects/auto-retrain-input-dataset",
+   "input_model": "microsoft/resnet-50",
+   "autotrain_project_prefix": "platma-retrain"
+ }
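
main.py reads these values via `from models import config, WebhookPayload`, but models.py itself is not part of this commit. Below is a minimal sketch of what such a loader could look like, assuming a pydantic BaseModel (pydantic ships with FastAPI) and that config.json sits next to models.py; the class name and file layout are illustrative, not the repository's actual models.py.

import json
from pathlib import Path

from pydantic import BaseModel


class Config(BaseModel):
    target_namespace: str
    input_dataset: str
    input_model: str
    autotrain_project_prefix: str


# Parse src/config.json (shown above) into a typed object importable as `config`
config = Config(**json.loads((Path(__file__).parent / "config.json").read_text()))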
src/config.yaml ADDED
@@ -0,0 +1,39 @@
+ task: llm-sft
+ base_model: meta-llama/Meta-Llama-3.1-8B-Instruct
+ project_name: llama
+ log: tensorboard
+ backend: local
+
+ data:
+   path: baryshych/platma
+   train_split: train
+   valid_split: null
+   chat_template: null
+   column_mapping:
+     text_column: text
+
+ params:
+   block_size: 1024
+   lr: 1e-4
+   warmup_ratio: 0.1
+   weight_decay: 0.01
+   epochs: 1
+   batch_size: 2
+   gradient_accumulation: 8
+   mixed_precision: fp16
+   peft: True
+   quantization: null
+   lora_r: 16
+   lora_alpha: 32
+   lora_dropout: 0.05
+   unsloth: False
+   optimizer: paged_adamw_8bit
+   target_modules: all-linear
+   padding: right
+   optimizer: paged_adamw_8bit
+   scheduler: cosine
+
+ hub:
+   username: baryshych
+   token: ${HF_ACCESS_TOKEN}
+   push_to_hub: True
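
main.py launches a run from this file by shelling out to the autotrain CLI (see the Popen call in the diff below), which returns immediately without waiting for training. A small sketch of the same launch that waits for completion and surfaces the exit code, assuming the autotrain-advanced CLI from requirements.txt is on PATH and HF_ACCESS_TOKEN is exported so the token placeholder in the hub section can be resolved.

import os
import subprocess

# Hypothetical standalone launch of the config above; autotrain reads
# src/config.yaml and pushes the result to the Hub when push_to_hub is True.
assert os.getenv("HF_ACCESS_TOKEN"), "export HF_ACCESS_TOKEN before launching"
result = subprocess.run(["autotrain", "--config", "src/config.yaml"])
print("autotrain exited with code", result.returncode)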
src/main.py CHANGED
@@ -1,21 +1,21 @@
  import os
  import requests
  from typing import Optional
+ import uvicorn
+ import subprocess
+ from subprocess import Popen

  from fastapi import FastAPI, Header, HTTPException, BackgroundTasks
  from fastapi.responses import FileResponse
  from huggingface_hub.hf_api import HfApi

- from .models import config, WebhookPayload
-
- WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
- HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
- AUTOTRAIN_API_URL = "https://api.autotrain.huggingface.co"
- AUTOTRAIN_UI_URL = "https://ui.autotrain.huggingface.co"
-
+ from models import config, WebhookPayload

  app = FastAPI()

+
+ WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
+ HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
  @app.get("/")
  async def home():
      return FileResponse("home.html")
@@ -26,23 +26,23 @@ async def post_webhook(
      task_queue: BackgroundTasks,
      x_webhook_secret: Optional[str] = Header(default=None),
  ):
-     if x_webhook_secret is None:
-         raise HTTPException(401)
-     if x_webhook_secret != WEBHOOK_SECRET:
-         raise HTTPException(403)
-     if not (
-         payload.event.action == "update"
-         and payload.event.scope.startswith("repo.content")
-         and payload.repo.name == config.input_dataset
-         and payload.repo.type == "dataset"
-     ):
-         # no-op
-         return {"processed": False}
-
-     task_queue.add_task(
-         schedule_retrain,
-         payload
-     )
+     # if x_webhook_secret is None:
+     #     raise HTTPException(401)
+     # if x_webhook_secret != WEBHOOK_SECRET:
+     #     raise HTTPException(403)
+     # if not (
+     #     payload.event.action == "update"
+     #     and payload.event.scope.startswith("repo.content")
+     #     and payload.repo.name == config.input_dataset
+     #     and payload.repo.type == "dataset"
+     # ):
+     #     # no-op
+     #     return {"processed": False}
+     schedule_retrain(payload=payload)
+     # task_queue.add_task(
+     #     schedule_retrain,
+     #     payload
+     # )

      return {"processed": True}

@@ -50,17 +50,18 @@ async def post_webhook(
  def schedule_retrain(payload: WebhookPayload):
      # Create the autotrain project
      try:
-         project = AutoTrain.create_project(payload)
-         AutoTrain.add_data(project_id=project["id"])
-         AutoTrain.start_processing(project_id=project["id"])
+         result = Popen(['autotrain', '--config', 'config.yaml'])
+         # project = AutoTrain.create_project(payload)
+         # AutoTrain.add_data(project_id=project["id"])
+         # AutoTrain.start_processing(project_id=project["id"])
      except requests.HTTPError as err:
          print("ERROR while requesting AutoTrain API:")
          print(f" code: {err.response.status_code}")
          print(f" {err.response.json()}")
          raise
      # Notify in the community tab
-     notify_success(project["id"])
-
+     notify_success('vicuna')
+     print(result.returncode)
      return {"processed": True}


@@ -68,15 +69,32 @@ class AutoTrain:
      @staticmethod
      def create_project(payload: WebhookPayload) -> dict:
          project_resp = requests.post(
-             f"{AUTOTRAIN_API_URL}/projects/create",
+             f"{AUTOTRAIN_API_URL}/api/create_project",
              json={
                  "username": config.target_namespace,
                  "proj_name": f"{config.autotrain_project_prefix}-{payload.repo.headSha[:7]}",
-                 "task": 18,  # image-multi-class-classification
-                 "config": {
-                     "hub-model": config.input_model,
-                     "max_models": 1,
-                     "language": "unk",
+                 "task": 'llm:sft',
+                 "base_model": "meta-llama/Meta-Llama-3-8B-Instruct",
+                 "train_split": "train",
+                 "column_mapping": {
+                     "text_column": "text",
+                 },
+                 "params": {
+                     "block_size": 1024,
+                     "model_max_length": 4096,
+                     "max_prompt_length": 512,
+                     "epochs": 1,
+                     "batch_size": 2,
+                     "lr": 0.00003,
+                     "peft": True,
+                     "quantization": "int4",
+                     "target_modules": "all-linear",
+                     "padding": "right",
+                     "optimizer": "adamw_torch",
+                     "scheduler": "linear",
+                     "gradient_accumulation": 4,
+                     "mixed_precision": "fp16",
+                     "chat_template": "chatml"
                  }
              },
              headers={
@@ -140,3 +158,6 @@ Please review and approve the project [here]({ui_url}/{project_id}/trainings) to

  (This is an automated message)
  """
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)
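
With the secret and payload checks commented out, any request that parses as a WebhookPayload now triggers a local retrain. A rough smoke test against the server started by the new `__main__` block; the route path and the payload fields are inferred from the hunk context and the commented-out checks (models.py is not in this diff), so treat them as assumptions rather than the actual schema.

import requests

# Field names inferred from the commented-out checks in post_webhook;
# the real schema lives in models.WebhookPayload, which this commit does not show.
payload = {
    "event": {"action": "update", "scope": "repo.content"},
    "repo": {
        "name": "huggingface-projects/auto-retrain-input-dataset",
        "type": "dataset",
        "headSha": "0000000000000000000000000000000000000000",  # placeholder sha
    },
}
# Assumes the endpoint is mounted at /webhook on the uvicorn server from __main__
resp = requests.post("http://localhost:8000/webhook", json=payload)
print(resp.status_code, resp.json())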