CristianLazoQuispe commited on
Commit
31fff2a
Β·
1 Parent(s): 74ec8db

load model in session

Browse files
Files changed (2) hide show
  1. app.py +7 -4
  2. src/demo.py +36 -0
app.py CHANGED
@@ -4,7 +4,7 @@ import logging
4
  import gradio as gr
5
  logging.basicConfig(level=logging.INFO)
6
  from src.utils import generate_centered_gaussian_noise
7
- from src.demo import resize,plot_flow,load_models,plot_diff
8
 
9
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
  img_shape = (1, 28, 28)
@@ -16,14 +16,16 @@ betas = torch.linspace(1e-4, 0.02, timesteps)
16
  alphas = 1.0 - betas
17
  alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
18
 
19
- model_diff,model_flow_standard,model_flow_localized = load_models(ENV,device=device)
20
 
 
 
21
 
22
  @torch.no_grad()
23
  def generate_diffusion_intermediates_streaming(label):
24
  logging.info("πŸš€ Starting Diffusion Generation")
25
  total_start = time.time()
26
 
 
27
 
28
 
29
  x = torch.randn(1, *img_shape).to(device)
@@ -88,10 +90,11 @@ def generate_flow_intermediates_streaming(label, noise_type):
88
  # Select noise and model
89
  if noise_type == "Localized":
90
  x = generate_centered_gaussian_noise((1, *img_shape)).to(device)
91
- model_flow = model_flow_localized
 
92
  else:
93
  x = torch.randn(1, *img_shape).to(device)
94
- model_flow = model_flow_standard
95
 
96
  y = torch.full((1,), label, dtype=torch.long, device=device)
97
  steps = 50
 
4
  import gradio as gr
5
  logging.basicConfig(level=logging.INFO)
6
  from src.utils import generate_centered_gaussian_noise
7
+ from src.demo import resize,plot_flow,plot_diff,load_model_diff,load_model_flow_localized,load_model_flow_standard
8
 
9
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
  img_shape = (1, 28, 28)
 
16
  alphas = 1.0 - betas
17
  alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
18
 
 
19
 
20
+ #model_diff,model_flow_standard,model_flow_localized = load_models(ENV,device=device)
21
+ #not catching models because of memory limit in free deployment
22
 
23
  @torch.no_grad()
24
  def generate_diffusion_intermediates_streaming(label):
25
  logging.info("πŸš€ Starting Diffusion Generation")
26
  total_start = time.time()
27
 
28
+ model_diff = load_model_diff(ENV,device=device)
29
 
30
 
31
  x = torch.randn(1, *img_shape).to(device)
 
90
  # Select noise and model
91
  if noise_type == "Localized":
92
  x = generate_centered_gaussian_noise((1, *img_shape)).to(device)
93
+ model_flow = load_model_flow_localized(ENV,device=device)
94
+
95
  else:
96
  x = torch.randn(1, *img_shape).to(device)
97
+ model_flow = load_model_flow_standard(ENV,device=device)
98
 
99
  y = torch.full((1,), label, dtype=torch.long, device=device)
100
  steps = 50
src/demo.py CHANGED
@@ -31,6 +31,42 @@ def load_models(ENV,device):
31
  model_flow_localized.eval()
32
 
33
  return model_diff_standard,model_flow_standard,model_flow_localized
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def resize(image,size=(200,200)):
35
  stretch_near = cv2.resize(image, size, interpolation = cv2.INTER_LINEAR)
36
  return stretch_near
 
31
  model_flow_localized.eval()
32
 
33
  return model_diff_standard,model_flow_standard,model_flow_localized
34
+
35
+
36
+ def load_model_diff(ENV,device):
37
+ if ENV=="DEPLOY":
38
+ model_path = hf_hub_download(repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching", filename="outputs/diffusion/diffusion_model.pth",cache_dir="models")
39
+ else:
40
+ model_path = "outputs/diffusion/diffusion_model.pth"
41
+ print("Diff Downloaded!")
42
+ model_diff_standard = ConditionalUNet().to(device)
43
+ model_diff_standard.load_state_dict(torch.load(model_path, map_location=device))
44
+ model_diff_standard.eval()
45
+ return model_diff_standard
46
+
47
+ def load_model_flow_standard(ENV,device):
48
+ if ENV=="DEPLOY":
49
+ model_path_standard = hf_hub_download(repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching", filename="outputs/flow_matching/flow_model.pth",cache_dir="models")
50
+ else:
51
+ model_path_standard = "outputs/flow_matching/flow_model.pth"
52
+ print("Flow Downloaded!")
53
+ model_flow_standard = ConditionalUNet().to(device)
54
+ model_flow_standard.load_state_dict(torch.load(model_path_standard, map_location=device))
55
+ model_flow_standard.eval()
56
+ return model_flow_standard
57
+
58
+ def load_model_flow_localized(ENV,device):
59
+ if ENV=="DEPLOY":
60
+ model_path_localized = hf_hub_download(repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching", filename="outputs/flow_matching/flow_model_localized_noise.pth",cache_dir="models")
61
+ else:
62
+ model_path_localized = "outputs/flow_matching/flow_model_localized_noise.pth"
63
+ print("Flow Downloaded!")
64
+ model_flow_localized = ConditionalUNet().to(device)
65
+ model_flow_localized.load_state_dict(torch.load(model_path_localized, map_location=device))
66
+ model_flow_localized.eval()
67
+ return model_flow_localized
68
+
69
+
70
  def resize(image,size=(200,200)):
71
  stretch_near = cv2.resize(image, size, interpolation = cv2.INTER_LINEAR)
72
  return stretch_near