JuanJoseMV commited on
Commit
3f5f788
·
1 Parent(s): 8f96165

load classifier weights

Browse files
.gitignore CHANGED
@@ -17,5 +17,3 @@ build/
17
  # VSCode
18
  .vscode/
19
  *.code-workspace
20
-
21
- behaviour_model/
 
17
  # VSCode
18
  .vscode/
19
  *.code-workspace
 
 
app.py CHANGED
@@ -24,7 +24,7 @@ def create_demo():
24
  emotion_model.eval()
25
 
26
  behaviour_model = get_behaviour_model(
27
- behaviour_model_path="behaviour_model/",
28
  device=device,
29
  )
30
 
 
24
  emotion_model.eval()
25
 
26
  behaviour_model = get_behaviour_model(
27
+ behaviour_model_path="classifier_weights.bin",
28
  device=device,
29
  )
30
 
src/audio_processor.py CHANGED
@@ -104,10 +104,22 @@ class AudioProcessor:
104
  str: The predicted emotion label.
105
 
106
  """
 
 
 
 
 
 
 
107
 
108
  print("Segmenting audio...")
109
  out = self.segmentation_model(
110
- inputs=audio_path,
 
 
 
 
 
111
  return_timestamps=True,
112
  )
113
 
 
104
  str: The predicted emotion label.
105
 
106
  """
107
+ try:
108
+ input_frames, _ = librosa.load(
109
+ audio_path,
110
+ sr=SAMPLING_RATE
111
+ )
112
+ except Exception as e:
113
+ gr.Error(f"Error loading audio file: {e}.")
114
 
115
  print("Segmenting audio...")
116
  out = self.segmentation_model(
117
+ inputs={
118
+ "raw": input_frames,
119
+ "sampling_rate": SAMPLING_RATE,
120
+ },
121
+ chunk_length_s=30,
122
+ stride_length_s=5,
123
  return_timestamps=True,
124
  )
125
 
src/model/behaviour_model.py CHANGED
@@ -1,4 +1,3 @@
1
- import os
2
  import argparse
3
  import torch
4
  from .make_model import make_model
@@ -20,9 +19,10 @@ hparams_dict = {
20
  }
21
  hparams = argparse.Namespace(**hparams_dict)
22
 
23
- def get_behaviour_model(behaviour_model_path, device):
24
- state_dict = torch.load(os.path.join(behaviour_model_path, 'pytorch_model.bin'), map_location=device)
25
  model = make_model(hparams)
26
- model.load_state_dict(state_dict)
 
27
 
28
  return model
 
 
1
  import argparse
2
  import torch
3
  from .make_model import make_model
 
19
  }
20
  hparams = argparse.Namespace(**hparams_dict)
21
 
22
+ def get_behaviour_model(classifier_weights_path, device):
23
+ state_dict = torch.load(classifier_weights_path, map_location=device)
24
  model = make_model(hparams)
25
+ model.classifier.load_state_dict(state_dict)
26
+ model.eval()
27
 
28
  return model
src/model/classifier_weights.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c77df5b5cd060698cf2ad93cef2f1b23795ef2faebc0860a5acddc6d87d47b3
3
+ size 1846682