Spaces:
Runtime error
Runtime error
Damian Stewart
committed on
Commit
·
6067469
1
Parent(s):
c8aa68b
save every N steps and loss logging
Browse files
- StableDiffuser.py +3 -3
- app.py +40 -18
- isolate_rng.py +73 -0
- train.py +22 -2
StableDiffuser.py
CHANGED
|
@@ -95,8 +95,8 @@ class StableDiffuser(torch.nn.Module):
|
|
| 95 |
def set_scheduler_timesteps(self, n_steps):
|
| 96 |
self.scheduler.set_timesteps(n_steps, device=self.unet.device)
|
| 97 |
|
| 98 |
-
def get_initial_latents(self, n_imgs,
|
| 99 |
-
noise = self.get_noise(n_imgs,
|
| 100 |
latents = noise * self.scheduler.init_noise_sigma
|
| 101 |
return latents
|
| 102 |
|
|
@@ -199,7 +199,7 @@ class StableDiffuser(torch.nn.Module):
|
|
| 199 |
prompts = [prompts]
|
| 200 |
|
| 201 |
self.set_scheduler_timesteps(n_steps)
|
| 202 |
-
latents = self.get_initial_latents(n_imgs,
|
| 203 |
text_embeddings = self.get_text_embeddings(prompts,negative_prompts,n_imgs=n_imgs)
|
| 204 |
end_iteration = end_iteration or n_steps
|
| 205 |
latents_steps, trace_steps = self.diffusion(
|
|
|
|
| 95 |
def set_scheduler_timesteps(self, n_steps):
|
| 96 |
self.scheduler.set_timesteps(n_steps, device=self.unet.device)
|
| 97 |
|
| 98 |
+
def get_initial_latents(self, n_imgs, height, width, n_prompts, generator=None):
|
| 99 |
+
noise = self.get_noise(n_imgs, height, width, generator=generator).repeat(n_prompts, 1, 1, 1)
|
| 100 |
latents = noise * self.scheduler.init_noise_sigma
|
| 101 |
return latents
|
| 102 |
|
|
|
|
| 199 |
prompts = [prompts]
|
| 200 |
|
| 201 |
self.set_scheduler_timesteps(n_steps)
|
| 202 |
+
latents = self.get_initial_latents(n_imgs, height, width, len(prompts), generator=generator)
|
| 203 |
text_embeddings = self.get_text_embeddings(prompts,negative_prompts,n_imgs=n_imgs)
|
| 204 |
end_iteration = end_iteration or n_steps
|
| 205 |
latents_steps, trace_steps = self.diffusion(
|
app.py
CHANGED
|
@@ -10,22 +10,16 @@ from memory_efficiency import MemoryEfficiencyWrapper
|
|
| 10 |
from train import train
|
| 11 |
|
| 12 |
import os
|
| 13 |
-
model_map = {'Van Gogh': 'models/vangogh.pt',
|
| 14 |
-
'Pablo Picasso': 'models/pablopicasso.pt',
|
| 15 |
-
'Car': 'models/car.pt',
|
| 16 |
-
'Garbage Truck': 'models/garbagetruck.pt',
|
| 17 |
-
'French Horn': 'models/frenchhorn.pt',
|
| 18 |
-
'Kilian Eng': 'models/kilianeng.pt',
|
| 19 |
-
'Thomas Kinkade': 'models/thomaskinkade.pt',
|
| 20 |
-
'Tyler Edlin': 'models/tyleredlin.pt',
|
| 21 |
-
'Kelly McKernan': 'models/kellymckernan.pt',
|
| 22 |
-
'Rembrandt': 'models/rembrandt.pt' }
|
| 23 |
-
for model_file in os.listdir('models'):
|
| 24 |
-
path = 'models/' + model_file
|
| 25 |
-
if any([existing_path == path for existing_path in model_map.values()]):
|
| 26 |
-
continue
|
| 27 |
-
model_map[model_file] = path
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
ORIGINAL_SPACE_ID = 'baulab/Erasing-Concepts-In-Diffusion'
|
| 31 |
SPACE_ID = os.getenv('SPACE_ID')
|
|
@@ -85,6 +79,10 @@ class Demo:
|
|
| 85 |
value='Van Gogh',
|
| 86 |
interactive=True
|
| 87 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
self.seed_infr = gr.Number(
|
| 90 |
label="Seed",
|
|
@@ -196,6 +194,11 @@ class Demo:
|
|
| 196 |
label="Seed",
|
| 197 |
info="Set to a fixed number for reproducible training results, or use -1 to pick randomly"
|
| 198 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
with gr.Column():
|
| 201 |
self.train_memory_options = gr.Markdown(interactive=False,
|
|
@@ -215,6 +218,10 @@ class Demo:
|
|
| 215 |
value="Train",
|
| 216 |
)
|
| 217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
self.download = gr.Files()
|
| 219 |
|
| 220 |
with gr.Tab("Export") as export_column:
|
|
@@ -268,7 +275,10 @@ class Demo:
|
|
| 268 |
self.image_orig
|
| 269 |
]
|
| 270 |
)
|
| 271 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 272 |
self.train_model_input,
|
| 273 |
self.train_img_size_input,
|
| 274 |
self.prompt_input,
|
|
@@ -281,9 +291,12 @@ class Demo:
|
|
| 281 |
self.train_use_amp_input,
|
| 282 |
self.train_use_gradient_checkpointing_input,
|
| 283 |
self.train_seed_input,
|
|
|
|
| 284 |
],
|
| 285 |
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
| 286 |
)
|
|
|
|
|
|
|
| 287 |
self.export_button.click(self.export, inputs = [
|
| 288 |
self.model_dropdown_export,
|
| 289 |
self.base_repo_id_or_path_input_export,
|
|
@@ -293,9 +306,15 @@ class Demo:
|
|
| 293 |
outputs=[self.export_status]
|
| 294 |
)
|
| 295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
|
| 297 |
use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
|
| 298 |
-
seed=-1,
|
| 299 |
pbar = gr.Progress(track_tqdm=True)):
|
| 300 |
|
| 301 |
if self.training:
|
|
@@ -331,10 +350,13 @@ class Demo:
|
|
| 331 |
|
| 332 |
try:
|
| 333 |
self.training = True
|
|
|
|
| 334 |
train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
|
| 335 |
-
|
|
|
|
| 336 |
finally:
|
| 337 |
self.training = False
|
|
|
|
| 338 |
|
| 339 |
torch.cuda.empty_cache()
|
| 340 |
|
|
|
|
| 10 |
from train import train
|
| 11 |
|
| 12 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
def populate_model_map():
    """Build the mapping of selectable model names to checkpoint paths.

    Scans the ``models`` directory (relative to the current working
    directory) and maps each entry name to ``'models/<entry name>'``.

    Returns:
        dict: entry name -> path, inserted in sorted name order. Empty when
        the ``models`` directory does not exist.
    """
    try:
        # os.listdir returns entries in arbitrary order; sort so the
        # resulting mapping (and anything built from its keys) is stable.
        model_files = sorted(os.listdir('models'))
    except FileNotFoundError:
        # No models directory yet: nothing to offer rather than a crash.
        return {}
    # Entry names inside one directory are unique, so every path is new and
    # no duplicate-path check is required.
    return {model_file: 'models/' + model_file for model_file in model_files}
|
| 22 |
+
model_map = populate_model_map()
|
| 23 |
|
| 24 |
ORIGINAL_SPACE_ID = 'baulab/Erasing-Concepts-In-Diffusion'
|
| 25 |
SPACE_ID = os.getenv('SPACE_ID')
|
|
|
|
| 79 |
value='Van Gogh',
|
| 80 |
interactive=True
|
| 81 |
)
|
| 82 |
+
self.model_reload_button = gr.Button(
|
| 83 |
+
value="🔄",
|
| 84 |
+
interactive=True
|
| 85 |
+
)
|
| 86 |
|
| 87 |
self.seed_infr = gr.Number(
|
| 88 |
label="Seed",
|
|
|
|
| 194 |
label="Seed",
|
| 195 |
info="Set to a fixed number for reproducible training results, or use -1 to pick randomly"
|
| 196 |
)
|
| 197 |
+
self.train_save_every_input = gr.Number(
|
| 198 |
+
value=-1,
|
| 199 |
+
label="Save every N steps",
|
| 200 |
+
info="If >0, save the model throughout training at the given step interval."
|
| 201 |
+
)
|
| 202 |
|
| 203 |
with gr.Column():
|
| 204 |
self.train_memory_options = gr.Markdown(interactive=False,
|
|
|
|
| 218 |
value="Train",
|
| 219 |
)
|
| 220 |
|
| 221 |
+
self.train_cancel_button = gr.Button(
|
| 222 |
+
value="Cancel training"
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
self.download = gr.Files()
|
| 226 |
|
| 227 |
with gr.Tab("Export") as export_column:
|
|
|
|
| 275 |
self.image_orig
|
| 276 |
]
|
| 277 |
)
|
| 278 |
+
self.model_reload_button.click(self.reload_models,
|
| 279 |
+
inputs=[self.model_dropdown],
|
| 280 |
+
outputs=[self.model_dropdown])
|
| 281 |
+
train_event = self.train_button.click(self.train, inputs = [
|
| 282 |
self.train_model_input,
|
| 283 |
self.train_img_size_input,
|
| 284 |
self.prompt_input,
|
|
|
|
| 291 |
self.train_use_amp_input,
|
| 292 |
self.train_use_gradient_checkpointing_input,
|
| 293 |
self.train_seed_input,
|
| 294 |
+
self.train_save_every_input,
|
| 295 |
],
|
| 296 |
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
| 297 |
)
|
| 298 |
+
self.train_cancel_button.click(lambda x: print("cancel pressed"), cancels=[train_event])
|
| 299 |
+
|
| 300 |
self.export_button.click(self.export, inputs = [
|
| 301 |
self.model_dropdown_export,
|
| 302 |
self.base_repo_id_or_path_input_export,
|
|
|
|
| 306 |
outputs=[self.export_status]
|
| 307 |
)
|
| 308 |
|
| 309 |
+
def reload_models(self, model_dropdown):
    """Rescan the models directory and refresh the model dropdown choices.

    Args:
        model_dropdown: the value currently selected in the model dropdown;
            it is passed back unchanged as the dropdown's value.

    Returns:
        A one-element list holding the dropdown update with the new choices.
    """
    # Rebind the module-level map so the rest of the module sees the rescan.
    global model_map
    model_map = populate_model_map()
    available_models = list(model_map.keys())
    dropdown_update = gr.Dropdown.update(choices=available_models, value=model_dropdown)
    return [dropdown_update]
|
| 314 |
+
|
| 315 |
def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
|
| 316 |
use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
|
| 317 |
+
seed=-1, save_every=-1,
|
| 318 |
pbar = gr.Progress(track_tqdm=True)):
|
| 319 |
|
| 320 |
if self.training:
|
|
|
|
| 350 |
|
| 351 |
try:
|
| 352 |
self.training = True
|
| 353 |
+
self.train_cancel_button.update(interactive=True)
|
| 354 |
train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
|
| 355 |
+
use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing,
|
| 356 |
+
seed=int(seed), save_every=int(save_every))
|
| 357 |
finally:
|
| 358 |
self.training = False
|
| 359 |
+
self.train_cancel_button.update(interactive=False)
|
| 360 |
|
| 361 |
torch.cuda.empty_cache()
|
| 362 |
|
isolate_rng.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# copy/pasted from pytorch lightning
|
| 2 |
+
# https://github.com/Lightning-AI/lightning/blob/0d52f4577310b5a1624bed4d23d49e37fb05af9e/src/lightning_fabric/utilities/seed.py
|
| 3 |
+
# and
|
| 4 |
+
# https://github.com/Lightning-AI/lightning/blob/98f7696d1681974d34fad59c03b4b58d9524ed13/src/pytorch_lightning/utilities/seed.py
|
| 5 |
+
|
| 6 |
+
# Copyright The Lightning team.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
|
| 20 |
+
from contextlib import contextmanager
|
| 21 |
+
from typing import Generator, Dict, Any
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import numpy as np
|
| 25 |
+
from random import getstate as python_get_rng_state
|
| 26 |
+
from random import setstate as python_set_rng_state
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
|
| 30 |
+
"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
|
| 31 |
+
states = {
|
| 32 |
+
"torch": torch.get_rng_state(),
|
| 33 |
+
"numpy": np.random.get_state(),
|
| 34 |
+
"python": python_get_rng_state(),
|
| 35 |
+
}
|
| 36 |
+
if include_cuda:
|
| 37 |
+
states["torch.cuda"] = torch.cuda.get_rng_state_all()
|
| 38 |
+
return states
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
|
| 42 |
+
"""Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current
|
| 43 |
+
process."""
|
| 44 |
+
torch.set_rng_state(rng_state_dict["torch"])
|
| 45 |
+
# torch.cuda rng_state is only included since v1.8.
|
| 46 |
+
if "torch.cuda" in rng_state_dict:
|
| 47 |
+
torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
|
| 48 |
+
np.random.set_state(rng_state_dict["numpy"])
|
| 49 |
+
version, state, gauss = rng_state_dict["python"]
|
| 50 |
+
python_set_rng_state((version, tuple(state), gauss))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@contextmanager
def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]:
    """A context manager that resets the global random state on exit to what it was before entering.

    It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators.
    The state is restored on every exit path, including when the body of the ``with`` block raises.

    Args:
        include_cuda: Whether to allow this function to also control the `torch.cuda` random number generator.
            Set this to ``False`` when using the function in a forked process where CUDA re-initialization is
            prohibited.

    Example:
        >>> import torch
        >>> torch.manual_seed(1)  # doctest: +ELLIPSIS
        <torch._C.Generator object at ...>
        >>> with isolate_rng():
        ...     [torch.rand(1) for _ in range(3)]
        [tensor([0.7576]), tensor([0.2793]), tensor([0.4031])]
        >>> torch.rand(1)
        tensor([0.7576])
    """
    states = _collect_rng_states(include_cuda)
    try:
        yield
    finally:
        # Restore in ``finally`` so an exception raised inside the ``with``
        # body cannot leave the global RNG state modified.
        _set_rng_states(states)
|
train.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
|
| 2 |
|
| 3 |
from accelerate.utils import set_seed
|
| 4 |
from torch.cuda.amp import autocast
|
|
@@ -8,11 +8,12 @@ from finetuning import FineTunedModel
|
|
| 8 |
import torch
|
| 9 |
from tqdm import tqdm
|
| 10 |
|
|
|
|
| 11 |
from memory_efficiency import MemoryEfficiencyWrapper
|
| 12 |
|
| 13 |
|
| 14 |
def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
|
| 15 |
-
use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1):
|
| 16 |
|
| 17 |
nsteps = 50
|
| 18 |
diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')
|
|
@@ -54,6 +55,9 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
|
|
| 54 |
seed = random.randint(0, 2 ** 30)
|
| 55 |
set_seed(int(seed))
|
| 56 |
|
|
|
|
|
|
|
|
|
|
| 57 |
for i in pbar:
|
| 58 |
with torch.no_grad():
|
| 59 |
diffuser.set_scheduler_timesteps(nsteps)
|
|
@@ -92,6 +96,22 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
|
|
| 92 |
memory_efficiency_wrapper.step(optimizer, loss)
|
| 93 |
optimizer.zero_grad()
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
torch.save(finetuner.state_dict(), save_path)
|
| 96 |
|
| 97 |
del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents
|
|
|
|
| 1 |
+
import random
|
| 2 |
|
| 3 |
from accelerate.utils import set_seed
|
| 4 |
from torch.cuda.amp import autocast
|
|
|
|
| 8 |
import torch
|
| 9 |
from tqdm import tqdm
|
| 10 |
|
| 11 |
+
from isolate_rng import isolate_rng
|
| 12 |
from memory_efficiency import MemoryEfficiencyWrapper
|
| 13 |
|
| 14 |
|
| 15 |
def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
|
| 16 |
+
use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1, save_every=-1):
|
| 17 |
|
| 18 |
nsteps = 50
|
| 19 |
diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')
|
|
|
|
| 55 |
seed = random.randint(0, 2 ** 30)
|
| 56 |
set_seed(int(seed))
|
| 57 |
|
| 58 |
+
prev_losses = []
|
| 59 |
+
start_loss = None
|
| 60 |
+
max_prev_loss_count = 10
|
| 61 |
for i in pbar:
|
| 62 |
with torch.no_grad():
|
| 63 |
diffuser.set_scheduler_timesteps(nsteps)
|
|
|
|
| 96 |
memory_efficiency_wrapper.step(optimizer, loss)
|
| 97 |
optimizer.zero_grad()
|
| 98 |
|
| 99 |
+
# print moving average loss
|
| 100 |
+
prev_losses.append(loss.detach().clone())
|
| 101 |
+
if len(prev_losses) > max_prev_loss_count:
|
| 102 |
+
prev_losses.pop(0)
|
| 103 |
+
if start_loss is None:
|
| 104 |
+
start_loss = prev_losses[-1]
|
| 105 |
+
if len(prev_losses) >= max_prev_loss_count:
|
| 106 |
+
moving_average_loss = sum(prev_losses) / len(prev_losses)
|
| 107 |
+
print(
|
| 108 |
+
f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()}")
|
| 109 |
+
else:
|
| 110 |
+
print(f"step {i}: loss={loss.item()}")
|
| 111 |
+
|
| 112 |
+
if save_every > 0 and ((i % save_every) == (save_every-1)):
|
| 113 |
+
torch.save(finetuner.state_dict(), save_path + f"__step_{i}.pt")
|
| 114 |
+
|
| 115 |
torch.save(finetuner.state_dict(), save_path)
|
| 116 |
|
| 117 |
del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents
|