rahul7star committed on
Commit
aee06c1
·
verified ·
1 Parent(s): 9a8aa5e

Update flux_train_ui.py

Browse files
Files changed (1) hide show
  1. flux_train_ui.py +36 -47
flux_train_ui.py CHANGED
@@ -23,17 +23,6 @@ import uuid
23
  from slugify import slugify
24
  import gradio as gr # Assuming gr is from gradio for error/warning handling
25
 
26
- os.makedirs("tmp", exist_ok=True)
27
- # Configure logging
28
- logging.basicConfig(
29
- level=logging.DEBUG,
30
- format='%(asctime)s - %(levelname)s - %(message)s',
31
- handlers=[
32
- logging.StreamHandler(), # Output to console
33
- logging.FileHandler('tmp/training.log') # Save logs to a file
34
- ]
35
- )
36
- logger = logging.getLogger(__name__)
37
 
38
  sys.path.insert(0, "ai-toolkit")
39
  from toolkit.job import get_job
@@ -190,8 +179,8 @@ def start_training(
190
  use_more_advanced_options,
191
  more_advanced_options,
192
  ):
193
- logger.info("Starting training process")
194
- logger.debug(f"Input parameters: lora_name={lora_name}, concept_sentence={concept_sentence}, "
195
  f"steps={steps}, lr={lr}, rank={rank}, model_to_train={model_to_train}, "
196
  f"low_vram={low_vram}, dataset_folder={dataset_folder}, "
197
  f"sample_1={sample_1}, sample_2={sample_2}, sample_3={sample_3}, "
@@ -199,44 +188,44 @@ def start_training(
199
  f"more_advanced_options={more_advanced_options}")
200
 
201
  push_to_hub = True
202
- logger.info("Checking LoRA name")
203
  if not lora_name:
204
- logger.error("LoRA name is empty or None")
205
  raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
206
 
207
  # Check Hugging Face permissions
208
  try:
209
  user_info = whoami()
210
- logger.debug(f"Hugging Face user info: {user_info}")
211
  if user_info["auth"]["accessToken"]["role"] == "write" or \
212
  "repo.edit" in user_info["auth"]["accessToken"]["fineGrained"]["scoped"][0]["permissions"]:
213
- logger.info(f"Starting training locally for user: {user_info['name']}. LoRA will be available locally and on Hugging Face.")
214
  else:
215
  push_to_hub = False
216
- logger.warning("No write access to Hugging Face. Training locally only.")
217
  gr.Warning("Started training locally. Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face")
218
  except Exception as e:
219
  push_to_hub = False
220
- logger.error(f"Error checking Hugging Face permissions: {str(e)}")
221
  gr.Warning("Started training locally. Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face")
222
 
223
- logger.info("Training started")
224
  slugged_lora_name = slugify(lora_name)
225
- logger.debug(f"Slugged LoRA name: {slugged_lora_name}")
226
 
227
  # Load the default config
228
  config_path_default = "config/examples/train_lora_flux_24gb.yaml"
229
- logger.info(f"Loading default config from: {config_path_default}")
230
  try:
231
  with open(config_path_default, "r") as f:
232
  config = yaml.safe_load(f)
233
- logger.debug(f"Loaded config: {config}")
234
  except Exception as e:
235
- logger.error(f"Failed to load config from {config_path_default}: {str(e)}")
236
  raise
237
 
238
  # Update the config with user inputs
239
- logger.info("Updating config with user inputs")
240
  try:
241
  config["config"]["name"] = slugged_lora_name
242
  config["config"]["process"][0]["model"]["low_vram"] = low_vram
@@ -247,31 +236,31 @@ def start_training(
247
  config["config"]["process"][0]["network"]["linear_alpha"] = int(rank)
248
  config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder
249
  config["config"]["process"][0]["save"]["push_to_hub"] = push_to_hub
250
- logger.debug(f"Updated config fields: name={slugged_lora_name}, low_vram={low_vram}, steps={steps}, "
251
  f"lr={lr}, rank={rank}, dataset_folder={dataset_folder}, push_to_hub={push_to_hub}")
252
  except KeyError as e:
253
- logger.error(f"Config structure error: Missing key {str(e)}")
254
  raise
255
  except Exception as e:
256
- logger.error(f"Error updating config: {str(e)}")
257
  raise
258
 
259
  # Handle Hugging Face repository settings
260
  if push_to_hub:
261
  try:
262
  username = whoami()["name"]
263
- logger.debug(f"Hugging Face username: {username}")
264
  config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
265
  config["config"]["process"][0]["save"]["hf_private"] = True
266
- logger.debug(f"Set Hugging Face repo: {username}/{slugged_lora_name}")
267
  except Exception as e:
268
- logger.error(f"Error retrieving Hugging Face username: {str(e)}")
269
  raise gr.Error("Error trying to retrieve your username. Are you sure you are logged in with Hugging Face?")
270
 
271
  # Handle concept sentence
272
  if concept_sentence:
273
  config["config"]["process"][0]["trigger_word"] = concept_sentence
274
- logger.debug(f"Set trigger_word: {concept_sentence}")
275
 
276
  # Handle sampling prompts
277
  if sample_1 or sample_2 or sample_3:
@@ -285,56 +274,56 @@ def start_training(
285
  config["config"]["process"][0]["sample"]["prompts"].append(sample_2)
286
  if sample_3:
287
  config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
288
- logger.debug(f"Sampling enabled with prompts: {config['config']['process'][0]['sample']['prompts']}")
289
  else:
290
  config["config"]["process"][0]["train"]["disable_sampling"] = True
291
- logger.debug("Sampling disabled")
292
 
293
  # Handle model selection
294
  if model_to_train == "schnell":
295
  config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
296
  config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
297
  config["config"]["process"][0]["sample"]["sample_steps"] = 4
298
- logger.debug("Using schnell model configuration")
299
 
300
  # Handle advanced options
301
  if use_more_advanced_options:
302
  try:
303
  more_advanced_options_dict = yaml.safe_load(more_advanced_options)
304
- logger.debug(f"Advanced options parsed: {more_advanced_options_dict}")
305
  config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
306
- logger.debug(f"Config after advanced options update: {config}")
307
  except Exception as e:
308
- logger.error(f"Error parsing or applying advanced options: {str(e)}")
309
  raise
310
 
311
  # Save the updated config
312
- logger.info("Saving updated config")
313
  random_config_name = str(uuid.uuid4())
314
  os.makedirs("tmp", exist_ok=True)
315
  config_path = f"tmp/{random_config_name}-{slugged_lora_name}.yaml"
316
  try:
317
  with open(config_path, "w") as f:
318
  yaml.dump(config, f)
319
- logger.info(f"Config saved to: {config_path}")
320
  except Exception as e:
321
- logger.error(f"Error saving config to {config_path}: {str(e)}")
322
  raise
323
 
324
  # Run the training job
325
- logger.info(f"Starting training job with config: {config_path}")
326
  try:
327
  job = get_job(config_path)
328
- logger.debug("Job object created successfully")
329
  job.run()
330
- logger.info("Training job completed")
331
  job.cleanup()
332
- logger.info("Job cleanup completed")
333
  except Exception as e:
334
- logger.error(f"Error during training job execution: {str(e)}")
335
  raise
336
 
337
- logger.info(f"Training completed successfully. Model saved as {slugged_lora_name}")
338
  return f"Training completed successfully. Model saved as {slugged_lora_name}"
339
 
340
 
 
23
  from slugify import slugify
24
  import gradio as gr # Assuming gr is from gradio for error/warning handling
25
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  sys.path.insert(0, "ai-toolkit")
28
  from toolkit.job import get_job
 
179
  use_more_advanced_options,
180
  more_advanced_options,
181
  ):
182
+ print("Starting training process")
183
+ print(f"Input parameters: lora_name={lora_name}, concept_sentence={concept_sentence}, "
184
  f"steps={steps}, lr={lr}, rank={rank}, model_to_train={model_to_train}, "
185
  f"low_vram={low_vram}, dataset_folder={dataset_folder}, "
186
  f"sample_1={sample_1}, sample_2={sample_2}, sample_3={sample_3}, "
 
188
  f"more_advanced_options={more_advanced_options}")
189
 
190
  push_to_hub = True
191
+ print("Checking LoRA name")
192
  if not lora_name:
193
+ print("LoRA name is empty or None")
194
  raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
195
 
196
  # Check Hugging Face permissions
197
  try:
198
  user_info = whoami()
199
+ print(f"Hugging Face user info: {user_info}")
200
  if user_info["auth"]["accessToken"]["role"] == "write" or \
201
  "repo.edit" in user_info["auth"]["accessToken"]["fineGrained"]["scoped"][0]["permissions"]:
202
+ print(f"Starting training locally for user: {user_info['name']}. LoRA will be available locally and on Hugging Face.")
203
  else:
204
  push_to_hub = False
205
+ print("No write access to Hugging Face. Training locally only.")
206
  gr.Warning("Started training locally. Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face")
207
  except Exception as e:
208
  push_to_hub = False
209
+ print(f"Error checking Hugging Face permissions: {str(e)}")
210
  gr.Warning("Started training locally. Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face")
211
 
212
+ print("Training started")
213
  slugged_lora_name = slugify(lora_name)
214
+ print(f"Slugged LoRA name: {slugged_lora_name}")
215
 
216
  # Load the default config
217
  config_path_default = "config/examples/train_lora_flux_24gb.yaml"
218
+ print(f"Loading default config from: {config_path_default}")
219
  try:
220
  with open(config_path_default, "r") as f:
221
  config = yaml.safe_load(f)
222
+ print(f"Loaded config: {config}")
223
  except Exception as e:
224
+ print(f"Failed to load config from {config_path_default}: {str(e)}")
225
  raise
226
 
227
  # Update the config with user inputs
228
+ print("Updating config with user inputs")
229
  try:
230
  config["config"]["name"] = slugged_lora_name
231
  config["config"]["process"][0]["model"]["low_vram"] = low_vram
 
236
  config["config"]["process"][0]["network"]["linear_alpha"] = int(rank)
237
  config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder
238
  config["config"]["process"][0]["save"]["push_to_hub"] = push_to_hub
239
+ print(f"Updated config fields: name={slugged_lora_name}, low_vram={low_vram}, steps={steps}, "
240
  f"lr={lr}, rank={rank}, dataset_folder={dataset_folder}, push_to_hub={push_to_hub}")
241
  except KeyError as e:
242
+ print(f"Config structure error: Missing key {str(e)}")
243
  raise
244
  except Exception as e:
245
+ print(f"Error updating config: {str(e)}")
246
  raise
247
 
248
  # Handle Hugging Face repository settings
249
  if push_to_hub:
250
  try:
251
  username = whoami()["name"]
252
+ print(f"Hugging Face username: {username}")
253
  config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
254
  config["config"]["process"][0]["save"]["hf_private"] = True
255
+ print(f"Set Hugging Face repo: {username}/{slugged_lora_name}")
256
  except Exception as e:
257
+ print(f"Error retrieving Hugging Face username: {str(e)}")
258
  raise gr.Error("Error trying to retrieve your username. Are you sure you are logged in with Hugging Face?")
259
 
260
  # Handle concept sentence
261
  if concept_sentence:
262
  config["config"]["process"][0]["trigger_word"] = concept_sentence
263
+ print(f"Set trigger_word: {concept_sentence}")
264
 
265
  # Handle sampling prompts
266
  if sample_1 or sample_2 or sample_3:
 
274
  config["config"]["process"][0]["sample"]["prompts"].append(sample_2)
275
  if sample_3:
276
  config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
277
+ print(f"Sampling enabled with prompts: {config['config']['process'][0]['sample']['prompts']}")
278
  else:
279
  config["config"]["process"][0]["train"]["disable_sampling"] = True
280
+ print("Sampling disabled")
281
 
282
  # Handle model selection
283
  if model_to_train == "schnell":
284
  config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
285
  config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
286
  config["config"]["process"][0]["sample"]["sample_steps"] = 4
287
+ print("Using schnell model configuration")
288
 
289
  # Handle advanced options
290
  if use_more_advanced_options:
291
  try:
292
  more_advanced_options_dict = yaml.safe_load(more_advanced_options)
293
+ print(f"Advanced options parsed: {more_advanced_options_dict}")
294
  config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
295
+ print(f"Config after advanced options update: {config}")
296
  except Exception as e:
297
+ print(f"Error parsing or applying advanced options: {str(e)}")
298
  raise
299
 
300
  # Save the updated config
301
+ print("Saving updated config")
302
  random_config_name = str(uuid.uuid4())
303
  os.makedirs("tmp", exist_ok=True)
304
  config_path = f"tmp/{random_config_name}-{slugged_lora_name}.yaml"
305
  try:
306
  with open(config_path, "w") as f:
307
  yaml.dump(config, f)
308
+ print(f"Config saved to: {config_path}")
309
  except Exception as e:
310
+ print(f"Error saving config to {config_path}: {str(e)}")
311
  raise
312
 
313
  # Run the training job
314
+ print(f"Starting training job with config: {config_path}")
315
  try:
316
  job = get_job(config_path)
317
+ print("Job object created successfully")
318
  job.run()
319
+ print("Training job completed")
320
  job.cleanup()
321
+ print("Job cleanup completed")
322
  except Exception as e:
323
+ print(f"Error during training job execution: {str(e)}")
324
  raise
325
 
326
+ print(f"Training completed successfully. Model saved as {slugged_lora_name}")
327
  return f"Training completed successfully. Model saved as {slugged_lora_name}"
328
 
329