CristianLazoQuispe committed
Commit a2afb6d · 1 Parent(s): 31fff2a

load model in session
Files changed (1)
  1. app.py +6 -6
app.py CHANGED
@@ -4,7 +4,7 @@ import logging
 import gradio as gr
 logging.basicConfig(level=logging.INFO)
 from src.utils import generate_centered_gaussian_noise
-from src.demo import resize,plot_flow,plot_diff,load_model_diff,load_model_flow_localized,load_model_flow_standard
+from src.demo import resize,plot_flow,plot_diff,load_models,load_model_diff,load_model_flow_localized,load_model_flow_standard
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 img_shape = (1, 28, 28)
@@ -17,15 +17,15 @@ alphas = 1.0 - betas
 alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
 
 
-#model_diff,model_flow_standard,model_flow_localized = load_models(ENV,device=device)
-#not catching models because of memory limit in free deployment
+model_diff,model_flow_standard,model_flow_localized = load_models(ENV,device=device)
+###not catching models because of memory limit in free deployment
 
 @torch.no_grad()
 def generate_diffusion_intermediates_streaming(label):
     logging.info("🚀 Starting Diffusion Generation")
     total_start = time.time()
 
-    model_diff = load_model_diff(ENV,device=device)
+    #model_diff = load_model_diff(ENV,device=device)
 
 
     x = torch.randn(1, *img_shape).to(device)
@@ -90,11 +90,11 @@ def generate_flow_intermediates_streaming(label, noise_type):
     # Select noise and model
     if noise_type == "Localized":
         x = generate_centered_gaussian_noise((1, *img_shape)).to(device)
-        model_flow = load_model_flow_localized(ENV,device=device)
+        model_flow = model_flow_localized # load_model_flow_localized(ENV,device=device)
 
     else:
         x = torch.randn(1, *img_shape).to(device)
-        model_flow = load_model_flow_standard(ENV,device=device)
+        model_flow = model_flow_standard #load_model_flow_standard(ENV,device=device)
 
     y = torch.full((1,), label, dtype=torch.long, device=device)
     steps = 50
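
In effect, the commit moves model construction out of the per-request handlers and into module scope, so each network is loaded once when the Space starts ("in session") and reused by every Gradio callback; the leftover comment presumably means "not caching". Below is a minimal sketch of that pattern under stated assumptions, plus a lazy variant that defers loading until the first request. The tiny load_model here is only a stand-in for the repo's src.demo.load_models(ENV, device=device), which actually builds three models (diffusion, standard flow, localized flow).

import functools

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"


def load_model(device: str) -> nn.Module:
    # Stand-in for an expensive checkpoint load such as src.demo.load_models.
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
    return model.eval()


# Previous behaviour: every request pays the full load cost again.
@torch.no_grad()
def generate_per_request(x: torch.Tensor) -> torch.Tensor:
    model = load_model(device)
    return model(x)


# This commit: load once at import time ("in session") and reuse in callbacks.
MODEL = load_model(device)

@torch.no_grad()
def generate_cached(x: torch.Tensor) -> torch.Tensor:
    return MODEL(x)


# Lazy alternative: defer loading until the first request, then cache it,
# which keeps startup memory low on a memory-limited free deployment.
@functools.lru_cache(maxsize=1)
def get_model() -> nn.Module:
    return load_model(device)

With the lazy variant a callback would call get_model() instead of reading the module-level global, and the first request absorbs the load time; the commit instead pays that cost at startup, which keeps every request fast as long as the Space's memory budget allows the models to stay resident.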