trying to initialize all weights
Browse files- app.py +3 -3
- model_utils/efficientnet_config.py +6 -3
app.py
CHANGED
@@ -117,11 +117,11 @@ def get_activations(model, image: list, model_name: str,
|
|
117 |
|
118 |
|
119 |
layer_outputs = {}
|
120 |
-
for i in range(len(model.features)):
|
121 |
-
image = model.features[i](image)
|
122 |
layer_outputs[i] = image
|
123 |
print(i, layer_outputs[i].shape)
|
124 |
-
output = model(image).detach().cpu().numpy()
|
125 |
output_1 = activation_indices[model_name].detach().cpu().numpy()
|
126 |
output_2 = activation_indices[model_name].detach().cpu().numpy()
|
127 |
|
|
|
117 |
|
118 |
|
119 |
layer_outputs = {}
|
120 |
+
for i in range(len(model.model.features)):
|
121 |
+
image = model.model.features[i](image)
|
122 |
layer_outputs[i] = image
|
123 |
print(i, layer_outputs[i].shape)
|
124 |
+
output = model.model(image).detach().cpu().numpy()
|
125 |
output_1 = activation_indices[model_name].detach().cpu().numpy()
|
126 |
output_2 = activation_indices[model_name].detach().cpu().numpy()
|
127 |
|
model_utils/efficientnet_config.py
CHANGED
@@ -303,7 +303,7 @@ class EfficientNetPreTrained(PreTrainedModel):
|
|
303 |
config
|
304 |
):
|
305 |
super().__init__(config)
|
306 |
-
self.model = EfficientNet(
|
307 |
num_channels=config.num_channels,
|
308 |
num_classes=config.num_classes,
|
309 |
size=config.size,
|
@@ -320,8 +320,11 @@ class EfficientNetPreTrained(PreTrainedModel):
|
|
320 |
# not all will have weights and biases
|
321 |
try:
|
322 |
module.weight.data.normal_(mean=0.0)
|
323 |
-
except
|
324 |
-
|
|
|
|
|
|
|
325 |
try:
|
326 |
module.bias.data.zero_()
|
327 |
except AttributeError:
|
|
|
303 |
config
|
304 |
):
|
305 |
super().__init__(config)
|
306 |
+
self.model = EfficientNet(dropout=config.dropout,
|
307 |
num_channels=config.num_channels,
|
308 |
num_classes=config.num_classes,
|
309 |
size=config.size,
|
|
|
320 |
# not all will have weights and biases
|
321 |
try:
|
322 |
module.weight.data.normal_(mean=0.0)
|
323 |
+
except Exception:
|
324 |
+
try:
|
325 |
+
module.weight.data.fill_(1.0)
|
326 |
+
except Exception:
|
327 |
+
_ = None
|
328 |
try:
|
329 |
module.bias.data.zero_()
|
330 |
except AttributeError:
|