noumanjavaid commited on
Commit
a64912b
·
verified ·
1 Parent(s): ba1da16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -37
app.py CHANGED
@@ -79,54 +79,58 @@ def predict(Seed, ckpt):
79
  seed_all(int(Seed))
80
 
81
  # --- MODIFICATION START ---
82
- # Try to get the original args from the checkpoint first
83
- # The key might be 'args', 'config', or something similar.
84
- # We need to inspect the actual keys of a loaded ckpt if this doesn't work.
85
  if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
86
  actual_args = ckpt['args']
87
  print("Using 'args' found in checkpoint.")
88
  else:
89
- # Fallback to constructing a mock_args if 'args' is not as expected
90
- # This part needs to be more robust and include all necessary defaults
91
  print("Warning: 'args' not found or 'args.model' missing in checkpoint. Constructing mock_args.")
92
-
93
- # Defaults - these might need to be adjusted based on the original training scripts
94
- # or by inspecting a correctly loaded checkpoint from the original repo.
95
  default_latent_dim = 128
96
- default_hyper = None # Or some sensible default if PointwiseNet/etc. need it
97
- default_residual = True # Common default for PointwiseNet, but needs verification
98
  default_flow_depth = 10
99
  default_flow_hidden_dim = 256
100
- default_model_type = 'gaussian' # Default if not found
101
- default_num_points = 2048
102
- default_flexibility = 0.0
103
-
104
- # Try to get values from ckpt if they exist at the top level
105
- # (some checkpoints might store them flatly instead of under an 'args' key)
106
- model_type = ckpt.get('model', default_model_type) # Check if 'model' key exists directly
107
- latent_dim = ckpt.get('latent_dim', default_latent_dim)
108
- hyper = ckpt.get('hyper', default_hyper)
109
- residual = ckpt.get('residual', default_residual)
110
- flow_depth = ckpt.get('flow_depth', default_flow_depth)
111
- flow_hidden_dim = ckpt.get('flow_hidden_dim', default_flow_hidden_dim)
112
- num_points_to_generate = ckpt.get('num_points', default_num_points)
113
- flexibility = ckpt.get('flexibility', default_flexibility)
114
-
115
- # Create the mock_args object
116
  actual_args = type('Args', (), {
117
- 'model': model_type,
118
- 'latent_dim': latent_dim,
119
- 'hyper': hyper,
120
- 'residual': residual, # Added residual
121
- 'flow_depth': flow_depth,
122
- 'flow_hidden_dim': flow_hidden_dim,
123
- 'num_points': num_points_to_generate,
124
- 'flexibility': flexibility
125
- # Add any other attributes that models might expect from 'args'
126
  })()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  # --- MODIFICATION END ---
128
 
129
- # Now use actual_args to instantiate models
130
  if actual_args.model == 'gaussian':
131
  model = GaussianVAE(actual_args).to(device)
132
  elif actual_args.model == 'flow':
@@ -139,6 +143,7 @@ def predict(Seed, ckpt):
139
 
140
  gen_pcs = []
141
  with torch.no_grad():
 
142
  z = torch.randn([1, actual_args.latent_dim]).to(device)
143
  x = model.sample(z, actual_args.num_points, flexibility=actual_args.flexibility)
144
  gen_pcs.append(x.detach().cpu())
@@ -147,7 +152,6 @@ def predict(Seed, ckpt):
147
  gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")
148
 
149
  return gen_pcs_normalized[0]
150
-
151
  def generate(seed, value):
152
  if value == "Airplane":
153
  ckpt = ckpt_airplane
 
79
  seed_all(int(Seed))
80
 
81
  # --- MODIFICATION START ---
82
+ actual_args = None
 
 
83
  if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
84
  actual_args = ckpt['args']
85
  print("Using 'args' found in checkpoint.")
86
  else:
87
+ # This fallback should ideally not be hit if 'args' is usually present
 
88
  print("Warning: 'args' not found or 'args.model' missing in checkpoint. Constructing mock_args.")
89
+ # Define all necessary defaults if we have to construct from scratch
90
+ default_model_type = 'gaussian'
 
91
  default_latent_dim = 128
92
+ default_hyper = None
93
+ default_residual = True
94
  default_flow_depth = 10
95
  default_flow_hidden_dim = 256
96
+ default_num_points = 2048 # Default for sampling
97
+ default_flexibility = 0.0 # Default for sampling
98
+
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  actual_args = type('Args', (), {
100
+ 'model': ckpt.get('model', default_model_type),
101
+ 'latent_dim': ckpt.get('latent_dim', default_latent_dim),
102
+ 'hyper': ckpt.get('hyper', default_hyper),
103
+ 'residual': ckpt.get('residual', default_residual),
104
+ 'flow_depth': ckpt.get('flow_depth', default_flow_depth),
105
+ 'flow_hidden_dim': ckpt.get('flow_hidden_dim', default_flow_hidden_dim),
106
+ 'num_points': ckpt.get('num_points', default_num_points), # Try to get from ckpt top-level too
107
+ 'flexibility': ckpt.get('flexibility', default_flexibility) # Try to get from ckpt top-level too
 
108
  })()
109
+
110
+ # Ensure essential attributes for sampling exist on actual_args, even if 'args' was found
111
+ # These are parameters for the .sample() method, not necessarily model construction.
112
+ # The original training args might not have included these if they were fixed in the sampling script.
113
+
114
+ # Default values for sampling parameters if not present in actual_args
115
+ default_num_points_sampling = 2048
116
+ default_flexibility_sampling = 0.0
117
+
118
+ if not hasattr(actual_args, 'num_points'):
119
+ print(f"Attribute 'num_points' not found in actual_args. Setting default: {default_num_points_sampling}")
120
+ setattr(actual_args, 'num_points', default_num_points_sampling)
121
+
122
+ if not hasattr(actual_args, 'flexibility'):
123
+ print(f"Attribute 'flexibility' not found in actual_args. Setting default: {default_flexibility_sampling}")
124
+ setattr(actual_args, 'flexibility', default_flexibility_sampling)
125
+
126
+ # Also ensure 'residual' is present if it's a Gaussian model, as it was an issue before
127
+ # This is more for model construction, but good to double-check if the 'args' from ckpt might be incomplete
128
+ if actual_args.model == 'gaussian' and not hasattr(actual_args, 'residual'):
129
+ print(f"Attribute 'residual' not found in actual_args for Gaussian model. Setting default: True")
130
+ setattr(actual_args, 'residual', True) # Default for GaussianVAE
131
+
132
  # --- MODIFICATION END ---
133
 
 
134
  if actual_args.model == 'gaussian':
135
  model = GaussianVAE(actual_args).to(device)
136
  elif actual_args.model == 'flow':
 
143
 
144
  gen_pcs = []
145
  with torch.no_grad():
146
+ # Use the (potentially now augmented) actual_args for sampling
147
  z = torch.randn([1, actual_args.latent_dim]).to(device)
148
  x = model.sample(z, actual_args.num_points, flexibility=actual_args.flexibility)
149
  gen_pcs.append(x.detach().cpu())
 
152
  gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")
153
 
154
  return gen_pcs_normalized[0]
 
155
  def generate(seed, value):
156
  if value == "Airplane":
157
  ckpt = ckpt_airplane