samusander commited on
Commit
cd14b84
·
1 Parent(s): a63c535

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -2
app.py CHANGED
@@ -1,4 +1,3 @@
1
- from PIL import Image
2
  import torch
3
  from tqdm.auto import tqdm
4
 
@@ -10,22 +9,45 @@ from point_e.util.plotting import plot_point_cloud
10
  import streamlit as st
11
 
12
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
13
  st.write('creating base model...')
14
- base_name = 'base40M' # use base300M or base1B for better results
15
  base_model = model_from_config(MODEL_CONFIGS[base_name], device)
16
  base_model.eval()
17
  base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])
 
18
  st.write('creating upsample model...')
19
  upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
20
  upsampler_model.eval()
21
  upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])
 
22
  st.write('downloading base checkpoint...')
23
  base_model.load_state_dict(load_checkpoint(base_name, device))
 
24
  st.write('downloading upsampler checkpoint...')
25
  upsampler_model.load_state_dict(load_checkpoint('upsample', device))
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
 
 
 
 
29
 
30
 
31
 
 
 
1
  import torch
2
  from tqdm.auto import tqdm
3
 
 
9
  import streamlit as st
10
 
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+
13
  st.write('creating base model...')
14
+ base_name = 'base40M-textvec'
15
  base_model = model_from_config(MODEL_CONFIGS[base_name], device)
16
  base_model.eval()
17
  base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])
18
+
19
  st.write('creating upsample model...')
20
  upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
21
  upsampler_model.eval()
22
  upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])
23
+
24
  st.write('downloading base checkpoint...')
25
  base_model.load_state_dict(load_checkpoint(base_name, device))
26
+
27
  st.write('downloading upsampler checkpoint...')
28
  upsampler_model.load_state_dict(load_checkpoint('upsample', device))
29
 
30
 
31
+ # Define Sampler
32
+ sampler = PointCloudSampler(
33
+ device=device,
34
+ models=[base_model, upsampler_model],
35
+ diffusions=[base_diffusion, upsampler_diffusion],
36
+ num_points=[1024, 4096 - 1024],
37
+ aux_channels=['R', 'G', 'B'],
38
+ guidance_scale=[3.0, 0.0],
39
+ model_kwargs_key_filter=('texts', ''), # Do not condition the upsampler at all
40
+ )
41
+
42
+
43
+ # Load an image to condition on.
44
+ prompt = st.sidebar.text_input("Prompt")
45
+
46
 
47
+ # Produce a sample from the model.
48
+ samples = None
49
+ for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[prompt]))):
50
+ samples = x
51
 
52
 
53