alexnasa commited on
Commit
6cc150a
·
verified ·
1 Parent(s): 7ac53c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -12,22 +12,22 @@ import torch
12
 
13
  print(f'torch version:{torch.__version__}')
14
 
15
- import subprocess
16
- import importlib, site, sys
17
 
18
- # Re-discover all .pth/.egg-link files
19
- for sitedir in site.getsitepackages():
20
- site.addsitedir(sitedir)
21
 
22
- # Clear caches so importlib will pick up new modules
23
- importlib.invalidate_caches()
24
 
25
- def sh(cmd): subprocess.check_call(cmd, shell=True)
26
 
27
- sh("pip install -U xformers --index-url https://download.pytorch.org/whl/cu126")
28
 
29
- # tell Python to re-scan site-packages now that the egg-link exists
30
- import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
31
 
32
  import torch.utils.checkpoint
33
  from pytorch_lightning import seed_everything
@@ -91,10 +91,10 @@ text_encoder.requires_grad_(False)
91
  unet.requires_grad_(False)
92
  controlnet.requires_grad_(False)
93
 
94
- unet.to("cuda")
95
- controlnet.to("cuda")
96
- unet.enable_xformers_memory_efficient_attention()
97
- controlnet.enable_xformers_memory_efficient_attention()
98
 
99
  # Get the validation pipeline
100
  validation_pipeline = StableDiffusionControlNetPipeline(
 
12
 
13
  print(f'torch version:{torch.__version__}')
14
 
15
+ # import subprocess
16
+ # import importlib, site, sys
17
 
18
+ # # Re-discover all .pth/.egg-link files
19
+ # for sitedir in site.getsitepackages():
20
+ # site.addsitedir(sitedir)
21
 
22
+ # # Clear caches so importlib will pick up new modules
23
+ # importlib.invalidate_caches()
24
 
25
+ # def sh(cmd): subprocess.check_call(cmd, shell=True)
26
 
27
+ # sh("pip install -U xformers --index-url https://download.pytorch.org/whl/cu126")
28
 
29
+ # # tell Python to re-scan site-packages now that the egg-link exists
30
+ # import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
31
 
32
  import torch.utils.checkpoint
33
  from pytorch_lightning import seed_everything
 
91
  unet.requires_grad_(False)
92
  controlnet.requires_grad_(False)
93
 
94
+ # unet.to("cuda")
95
+ # controlnet.to("cuda")
96
+ # unet.enable_xformers_memory_efficient_attention()
97
+ # controlnet.enable_xformers_memory_efficient_attention()
98
 
99
  # Get the validation pipeline
100
  validation_pipeline = StableDiffusionControlNetPipeline(