cocktailpeanut committed
Commit ae6e97b · Parent(s): 182b0f0

user-friendly wandb support

Files changed (2)
  1. README.md +21 -0
  2. model/trainer.py +24 -20
README.md CHANGED
@@ -67,6 +67,27 @@ An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discuss
 
 Gradio UI finetuning with `finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
 
+## Wandb Logging
+
+By default, the training script does NOT use wandb logging (assuming you have not already logged in with `wandb login`).
+
+To turn on wandb logging, you can either:
+
+1. Manually log in with `wandb login`: learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
+2. Automatically log in by setting an environment variable: get an API key at https://wandb.ai/site/ and set the variable as follows:
+
+On Mac & Linux:
+
+```
+export WANDB_API_KEY=<YOUR WANDB API KEY>
+```
+
+On Windows:
+
+```
+set WANDB_API_KEY=<YOUR WANDB API KEY>
+```
+
 ## Inference
 
 The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or automatically downloaded with `inference-cli` and `gradio_app`.
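Note: as a minimal sketch (not part of this commit), the same environment-variable login can also be done from Python before launching training; `wandb.login()` is the standard wandb client call, and the key value below is a placeholder.

```
# Minimal sketch, not part of this commit: programmatic wandb login before training.
# The key below is a placeholder; wandb.login() also picks up WANDB_API_KEY if already set.
import os
import wandb

os.environ["WANDB_API_KEY"] = "<YOUR WANDB API KEY>"
wandb.login()  # authenticates using the environment variable, like `wandb login` would
```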
model/trainer.py CHANGED
@@ -50,31 +50,35 @@ class Trainer:
 
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
 
+        logger = "wandb" if wandb.api.api_key else None
+        print(f"Using logger: {logger}")
+
         self.accelerator = Accelerator(
-            log_with = "wandb",
+            log_with = logger,
             kwargs_handlers = [ddp_kwargs],
             gradient_accumulation_steps = grad_accumulation_steps,
             **accelerate_kwargs
         )
-
-        if exists(wandb_resume_id):
-            init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
-        else:
-            init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
-        self.accelerator.init_trackers(
-            project_name = wandb_project,
-            init_kwargs=init_kwargs,
-            config={"epochs": epochs,
-                    "learning_rate": learning_rate,
-                    "num_warmup_updates": num_warmup_updates,
-                    "batch_size": batch_size,
-                    "batch_size_type": batch_size_type,
-                    "max_samples": max_samples,
-                    "grad_accumulation_steps": grad_accumulation_steps,
-                    "max_grad_norm": max_grad_norm,
-                    "gpus": self.accelerator.num_processes,
-                    "noise_scheduler": noise_scheduler}
-            )
+
+        if logger == "wandb":
+            if exists(wandb_resume_id):
+                init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
+            else:
+                init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
+            self.accelerator.init_trackers(
+                project_name = wandb_project,
+                init_kwargs=init_kwargs,
+                config={"epochs": epochs,
+                        "learning_rate": learning_rate,
+                        "num_warmup_updates": num_warmup_updates,
+                        "batch_size": batch_size,
+                        "batch_size_type": batch_size_type,
+                        "max_samples": max_samples,
+                        "grad_accumulation_steps": grad_accumulation_steps,
+                        "max_grad_norm": max_grad_norm,
+                        "gpus": self.accelerator.num_processes,
+                        "noise_scheduler": noise_scheduler}
+                )
 
         self.model = model
 
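For reference, a condensed sketch of the detection pattern introduced above: the wandb tracker is only handed to Accelerate when an API key is available (via `wandb login` or `WANDB_API_KEY`), so logging stays off by default. The project name and init kwargs here are placeholders, not the repo's actual values.

```
# Condensed sketch of the pattern above; project name and run kwargs are placeholders.
import wandb
from accelerate import Accelerator

logger = "wandb" if wandb.api.api_key else None  # None when no credentials are found
print(f"Using logger: {logger}")

accelerator = Accelerator(log_with=logger)  # log_with=None means no tracker logging
if logger == "wandb":
    accelerator.init_trackers(
        project_name="my_project",  # placeholder
        init_kwargs={"wandb": {"resume": "allow", "name": "my_run"}},  # placeholder
    )
```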