broadfield-dev commited on
Commit
b02989a
·
verified ·
1 Parent(s): e16d9dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -172,20 +172,20 @@ def run_inference(input_tensor):
172
  with torch.no_grad():
173
  with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
174
  prediction = model(input_batch)
175
-
176
- # *** FIX: Convert from BFloat16 to Float32 before returning, making it NumPy compatible ***
177
  return prediction.cpu().to(torch.float32)
178
 
179
  def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
180
  if last_input_map is None: return None, None, None
181
  c_idx = SDO_CHANNELS.index(channel_name)
182
 
183
- scalers_dict = APP_CACHE["scalers"]
184
- scaler = scalers_dict[channel_name]
185
- params = scaler.to_dict()
186
- mean, std, epsilon, sl_scale_factor = params['mean'], params['std'], params['epsilon'], params['sl_scale_factor']
 
 
 
187
 
188
- # The prediction_tensor is now Float32, so .numpy() will work
189
  pred_slice = inverse_transform_single_channel(
190
  prediction_tensor[0, c_idx].numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
191
  )
 
172
  with torch.no_grad():
173
  with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
174
  prediction = model(input_batch)
 
 
175
  return prediction.cpu().to(torch.float32)
176
 
177
  def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
178
  if last_input_map is None: return None, None, None
179
  c_idx = SDO_CHANNELS.index(channel_name)
180
 
181
+ # *** FIX: Access the specific scaler for the channel from the dictionary ***
182
+ scaler = APP_CACHE["scalers"][channel_name]
183
+ # *** FIX: Access the parameters as attributes, not from to_dict() ***
184
+ mean = scaler.mean
185
+ std = scaler.std
186
+ epsilon = scaler.epsilon
187
+ sl_scale_factor = scaler.sl_scale_factor
188
 
 
189
  pred_slice = inverse_transform_single_channel(
190
  prediction_tensor[0, c_idx].numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
191
  )