Update app.py
app.py CHANGED
@@ -79,54 +79,58 @@ def predict(Seed, ckpt):
     seed_all(int(Seed))
 
     # --- MODIFICATION START ---
-
-    # The key might be 'args', 'config', or something similar.
-    # We need to inspect the actual keys of a loaded ckpt if this doesn't work.
+    actual_args = None
     if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
         actual_args = ckpt['args']
         print("Using 'args' found in checkpoint.")
     else:
-        #
-        # This part needs to be more robust and include all necessary defaults
+        # This fallback should ideally not be hit if 'args' is usually present
         print("Warning: 'args' not found or 'args.model' missing in checkpoint. Constructing mock_args.")
-
-
-        # or by inspecting a correctly loaded checkpoint from the original repo.
+        # Define all necessary defaults if we have to construct from scratch
+        default_model_type = 'gaussian'
         default_latent_dim = 128
-        default_hyper = None
-        default_residual = True
+        default_hyper = None
+        default_residual = True
         default_flow_depth = 10
         default_flow_hidden_dim = 256
-
-
-
-
-        # Try to get values from ckpt if they exist at the top level
-        # (some checkpoints might store them flatly instead of under an 'args' key)
-        model_type = ckpt.get('model', default_model_type) # Check if 'model' key exists directly
-        latent_dim = ckpt.get('latent_dim', default_latent_dim)
-        hyper = ckpt.get('hyper', default_hyper)
-        residual = ckpt.get('residual', default_residual)
-        flow_depth = ckpt.get('flow_depth', default_flow_depth)
-        flow_hidden_dim = ckpt.get('flow_hidden_dim', default_flow_hidden_dim)
-        num_points_to_generate = ckpt.get('num_points', default_num_points)
-        flexibility = ckpt.get('flexibility', default_flexibility)
-
-        # Create the mock_args object
+        default_num_points = 2048 # Default for sampling
+        default_flexibility = 0.0 # Default for sampling
+
         actual_args = type('Args', (), {
-            'model': model_type,
-            'latent_dim': latent_dim,
-            'hyper': hyper,
-            'residual': residual,
-            'flow_depth': flow_depth,
-            'flow_hidden_dim': flow_hidden_dim,
-            'num_points': num_points_to_generate,
-            'flexibility': flexibility
-            # Add any other attributes that models might expect from 'args'
+            'model': ckpt.get('model', default_model_type),
+            'latent_dim': ckpt.get('latent_dim', default_latent_dim),
+            'hyper': ckpt.get('hyper', default_hyper),
+            'residual': ckpt.get('residual', default_residual),
+            'flow_depth': ckpt.get('flow_depth', default_flow_depth),
+            'flow_hidden_dim': ckpt.get('flow_hidden_dim', default_flow_hidden_dim),
+            'num_points': ckpt.get('num_points', default_num_points), # Try to get from ckpt top-level too
+            'flexibility': ckpt.get('flexibility', default_flexibility) # Try to get from ckpt top-level too
         })()
+
+    # Ensure essential attributes for sampling exist on actual_args, even if 'args' was found
+    # These are parameters for the .sample() method, not necessarily model construction.
+    # The original training args might not have included these if they were fixed in the sampling script.
+
+    # Default values for sampling parameters if not present in actual_args
+    default_num_points_sampling = 2048
+    default_flexibility_sampling = 0.0
+
+    if not hasattr(actual_args, 'num_points'):
+        print(f"Attribute 'num_points' not found in actual_args. Setting default: {default_num_points_sampling}")
+        setattr(actual_args, 'num_points', default_num_points_sampling)
+
+    if not hasattr(actual_args, 'flexibility'):
+        print(f"Attribute 'flexibility' not found in actual_args. Setting default: {default_flexibility_sampling}")
+        setattr(actual_args, 'flexibility', default_flexibility_sampling)
+
+    # Also ensure 'residual' is present if it's a Gaussian model, as it was an issue before
+    # This is more for model construction, but good to double-check if the 'args' from ckpt might be incomplete
+    if actual_args.model == 'gaussian' and not hasattr(actual_args, 'residual'):
+        print(f"Attribute 'residual' not found in actual_args for Gaussian model. Setting default: True")
+        setattr(actual_args, 'residual', True) # Default for GaussianVAE
+
     # --- MODIFICATION END ---
 
-    # Now use actual_args to instantiate models
     if actual_args.model == 'gaussian':
        model = GaussianVAE(actual_args).to(device)
     elif actual_args.model == 'flow':
@@ -139,6 +143,7 @@ def predict(Seed, ckpt):
 
     gen_pcs = []
     with torch.no_grad():
+        # Use the (potentially now augmented) actual_args for sampling
         z = torch.randn([1, actual_args.latent_dim]).to(device)
         x = model.sample(z, actual_args.num_points, flexibility=actual_args.flexibility)
         gen_pcs.append(x.detach().cpu())
@@ -147,7 +152,6 @@ def predict(Seed, ckpt):
     gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")
 
     return gen_pcs_normalized[0]
-
 def generate(seed, value):
     if value == "Airplane":
         ckpt = ckpt_airplane