noumanjavaid commited on
Commit
ba1da16
·
verified ·
1 Parent(s): ada67da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -39
app.py CHANGED
@@ -76,55 +76,75 @@ def normalize_point_clouds(pcs, mode):
76
  def predict(Seed, ckpt):
77
  if Seed is None:
78
  Seed = 777
79
- seed_all(int(Seed)) # Ensure Seed is an integer
80
-
81
- # Ensure args is accessible, provide a default if it's missing or not a Namespace
82
- # This is a defensive measure, as the error was about loading argparse.Namespace
83
- if not hasattr(ckpt, 'args') or not hasattr(ckpt['args'], 'model'):
84
- # This case should ideally not happen if the checkpoint is valid
85
- # but if it does, we need a fallback or error.
86
- # For now, let's assume 'args' and 'args.model' exist based on the error.
87
- print("Warning: Checkpoint 'args' or 'args.model' not found. Assuming 'gaussian'.")
88
- model_type = 'gaussian'
89
- latent_dim = ckpt.get('latent_dim', 128) # A common default
90
- flexibility = ckpt.get('flexibility', 0.0) # A common default
91
  else:
92
- model_type = ckpt['args'].model
93
- latent_dim = ckpt['args'].latent_dim
94
- flexibility = ckpt['args'].flexibility
95
-
96
-
97
- if model_type == 'gaussian':
98
- # Pass necessary args to the constructor
99
- # We need to mock an args object if ckpt['args'] wasn't a full argparse.Namespace
100
- # or if some attributes are missing.
101
- mock_args = type('Args', (), {'latent_dim': latent_dim, 'hyper': getattr(ckpt.get('args', {}), 'hyper', None)})() # Add other required args
102
- model = GaussianVAE(mock_args).to(device)
103
- elif model_type == 'flow':
104
- mock_args = type('Args', (), {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  'latent_dim': latent_dim,
106
- 'flow_depth': getattr(ckpt.get('args', {}), 'flow_depth', 10), # Example default
107
- 'flow_hidden_dim': getattr(ckpt.get('args', {}), 'flow_hidden_dim', 256), # Example default
108
- 'hyper': getattr(ckpt.get('args', {}), 'hyper', None)
 
 
 
 
109
  })()
110
- model = FlowVAE(mock_args).to(device)
 
 
 
 
 
 
111
  else:
112
- raise ValueError(f"Unknown model type: {model_type}")
113
 
114
  model.load_state_dict(ckpt['state_dict'])
115
- model.eval() # Set model to evaluation mode
116
 
117
- # Generate Point Clouds
118
  gen_pcs = []
119
  with torch.no_grad():
120
- z = torch.randn([1, latent_dim]).to(device)
121
- # The sample method might also depend on args from the checkpoint
122
- num_points_to_generate = getattr(ckpt.get('args', {}), 'num_points', 2048) # Default to 2048 if not in args
123
- x = model.sample(z, num_points_to_generate, flexibility=flexibility)
124
  gen_pcs.append(x.detach().cpu())
125
-
126
- gen_pcs_tensor = torch.cat(gen_pcs, dim=0)[:1] # Ensure we take only one point cloud
127
- gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox") # Use .clone() if normalize_point_clouds modifies inplace
128
 
129
  return gen_pcs_normalized[0]
130
 
 
76
  def predict(Seed, ckpt):
77
  if Seed is None:
78
  Seed = 777
79
+ seed_all(int(Seed))
80
+
81
+ # --- MODIFICATION START ---
82
+ # Try to get the original args from the checkpoint first
83
+ # The key might be 'args', 'config', or something similar.
84
+ # We need to inspect the actual keys of a loaded ckpt if this doesn't work.
85
+ if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
86
+ actual_args = ckpt['args']
87
+ print("Using 'args' found in checkpoint.")
 
 
 
88
  else:
89
+ # Fallback to constructing a mock_args if 'args' is not as expected
90
+ # This part needs to be more robust and include all necessary defaults
91
+ print("Warning: 'args' not found or 'args.model' missing in checkpoint. Constructing mock_args.")
92
+
93
+ # Defaults - these might need to be adjusted based on the original training scripts
94
+ # or by inspecting a correctly loaded checkpoint from the original repo.
95
+ default_latent_dim = 128
96
+ default_hyper = None # Or some sensible default if PointwiseNet/etc. need it
97
+ default_residual = True # Common default for PointwiseNet, but needs verification
98
+ default_flow_depth = 10
99
+ default_flow_hidden_dim = 256
100
+ default_model_type = 'gaussian' # Default if not found
101
+ default_num_points = 2048
102
+ default_flexibility = 0.0
103
+
104
+ # Try to get values from ckpt if they exist at the top level
105
+ # (some checkpoints might store them flatly instead of under an 'args' key)
106
+ model_type = ckpt.get('model', default_model_type) # Check if 'model' key exists directly
107
+ latent_dim = ckpt.get('latent_dim', default_latent_dim)
108
+ hyper = ckpt.get('hyper', default_hyper)
109
+ residual = ckpt.get('residual', default_residual)
110
+ flow_depth = ckpt.get('flow_depth', default_flow_depth)
111
+ flow_hidden_dim = ckpt.get('flow_hidden_dim', default_flow_hidden_dim)
112
+ num_points_to_generate = ckpt.get('num_points', default_num_points)
113
+ flexibility = ckpt.get('flexibility', default_flexibility)
114
+
115
+ # Create the mock_args object
116
+ actual_args = type('Args', (), {
117
+ 'model': model_type,
118
  'latent_dim': latent_dim,
119
+ 'hyper': hyper,
120
+ 'residual': residual, # Added residual
121
+ 'flow_depth': flow_depth,
122
+ 'flow_hidden_dim': flow_hidden_dim,
123
+ 'num_points': num_points_to_generate,
124
+ 'flexibility': flexibility
125
+ # Add any other attributes that models might expect from 'args'
126
  })()
127
+ # --- MODIFICATION END ---
128
+
129
+ # Now use actual_args to instantiate models
130
+ if actual_args.model == 'gaussian':
131
+ model = GaussianVAE(actual_args).to(device)
132
+ elif actual_args.model == 'flow':
133
+ model = FlowVAE(actual_args).to(device)
134
  else:
135
+ raise ValueError(f"Unknown model type: {actual_args.model}")
136
 
137
  model.load_state_dict(ckpt['state_dict'])
138
+ model.eval()
139
 
 
140
  gen_pcs = []
141
  with torch.no_grad():
142
+ z = torch.randn([1, actual_args.latent_dim]).to(device)
143
+ x = model.sample(z, actual_args.num_points, flexibility=actual_args.flexibility)
 
 
144
  gen_pcs.append(x.detach().cpu())
145
+
146
+ gen_pcs_tensor = torch.cat(gen_pcs, dim=0)[:1]
147
+ gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")
148
 
149
  return gen_pcs_normalized[0]
150