rahul7star commited on
Commit
4235d82
·
verified ·
1 Parent(s): bd4a5d2

Update flux_train.py

Browse files
Files changed (1) hide show
  1. flux_train.py +24 -2
flux_train.py CHANGED
@@ -2,14 +2,28 @@
2
 
3
  import os
4
  from collections import OrderedDict
 
5
  import sys
6
  sys.path.append("/app/ai-toolkit") # Tell Python to look here
7
 
 
8
  from toolkit.job import run_job
9
 
 
 
10
 
 
 
 
 
 
 
 
 
 
11
 
12
- def build_job(concept="ohamlab style", training_path="/tmp/data", lora_name="ohami_filter_autorun"):
 
13
  job = OrderedDict([
14
  ('job', 'extension'),
15
  ('config', OrderedDict([
@@ -27,7 +41,8 @@ def build_job(concept="ohamlab style", training_path="/tmp/data", lora_name="oha
27
  ('save', OrderedDict([
28
  ('dtype', 'float16'),
29
  ('save_every', 250),
30
- ('max_step_saves_to_keep', 4)
 
31
  ])),
32
  ('datasets', [
33
  OrderedDict([
@@ -81,4 +96,11 @@ def build_job(concept="ohamlab style", training_path="/tmp/data", lora_name="oha
81
  ('version', '1.0')
82
  ]))
83
  ])
 
 
 
84
  return job
 
 
 
 
 
2
 
3
  import os
4
  from collections import OrderedDict
5
+ from huggingface_hub import whoami
6
  import sys
7
  sys.path.append("/app/ai-toolkit") # Tell Python to look here
8
 
9
+ from toolkit.job import run_job
10
  from toolkit.job import run_job
11
 
12
+ def update_config_push_to_hub(config, push_to_hub: bool, slugged_lora_name: str):
13
+ config["config"]["process"][0]["save"]["push_to_hub"] = push_to_hub
14
 
15
+ if push_to_hub:
16
+ try:
17
+ username = whoami()["name"]
18
+ except Exception:
19
+ raise RuntimeError(
20
+ "Error trying to retrieve your username. Are you sure you are logged in with Hugging Face?"
21
+ )
22
+ config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
23
+ config["config"]["process"][0]["save"]["hf_private"] = True
24
 
25
+ def build_job(concept="ohamlab style", training_path="/tmp/data", lora_name="ohami_filter_autorun", push_to_hub=False):
26
+ slugged_lora_name = lora_name.lower().replace(" ", "_")
27
  job = OrderedDict([
28
  ('job', 'extension'),
29
  ('config', OrderedDict([
 
41
  ('save', OrderedDict([
42
  ('dtype', 'float16'),
43
  ('save_every', 250),
44
+ ('max_step_saves_to_keep', 4),
45
+ # push_to_hub keys added later by update_config_push_to_hub()
46
  ])),
47
  ('datasets', [
48
  OrderedDict([
 
96
  ('version', '1.0')
97
  ]))
98
  ])
99
+
100
+ # Add push to Hub config if requested
101
+ update_config_push_to_hub(job, push_to_hub, slugged_lora_name)
102
  return job
103
+
104
+ def run_training(concept="ohamlab style", training_path="/tmp/data", lora_name="ohami_filter_autorun", push_to_hub=False):
105
+ job = build_job(concept, training_path, lora_name, push_to_hub)
106
+ run_job(job)