Shivdutta committed on
Commit 55e8481 · verified · 1 Parent(s): ad8bc10

Update app.py

Files changed (1)
  1. app.py +28 -3
app.py CHANGED
@@ -53,14 +53,39 @@ projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
 resblock = SimpleResBlock(phi_embed).to(device)
 phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name,trust_remote_code=True).to(device)
 # Load the model with the appropriate compute_type
+# Load the audio model with appropriate compute_type
 audio_model_size = "tiny"
+compute_type = "float32"  # Ensure using a compatible compute type
 try:
-    audio_model = whisperx.load_model(audio_model_size, device, compute_type=compute_type)
+    audio_model = whisperx.load_model(
+        audio_model_size,
+        device,
+        compute_type=compute_type,
+        # Provide necessary parameters based on your version of whisperx
+        # Check documentation for required parameters
+        max_new_tokens=100,  # Example values, adjust as needed
+        clip_timestamps=True,
+        hallucination_silence_threshold=0.5,
+        hotwords=None  # Add specific hotwords if needed
+    )
     print(f"Model loaded successfully with compute_type: {compute_type}")
 except ValueError as e:
     print(f"Error loading model: {e}")
-    print("Falling back to int8 compute type")
-    audio_model = whisperx.load_model(audio_model_size, device, compute_type="int8")
+    # Optionally, try loading with int8 if necessary
+    try:
+        audio_model = whisperx.load_model(
+            audio_model_size,
+            device,
+            compute_type="int8",
+            max_new_tokens=100,
+            clip_timestamps=True,
+            hallucination_silence_threshold=0.5,
+            hotwords=None
+        )
+        print("Fell back to int8 compute type successfully.")
+    except Exception as e:
+        print(f"Failed to load model with int8: {e}")
+

 # load weights
 model_to_merge = PeftModel.from_pretrained(phi_model,'./model_chkpt/lora_adaptor')
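Note: depending on the installed whisperx release, load_model may not accept these decoding options as direct keyword arguments; some versions expect them inside the asr_options dict instead. A minimal, hypothetical sketch of that variant (parameter values mirror the ones in the diff above and are illustrative only, not part of this commit):

import whisperx

# Sketch only: pass the decoding options through asr_options instead of as
# top-level keyword arguments, for whisperx versions that expect them there.
audio_model = whisperx.load_model(
    "tiny",                      # same audio_model_size as used above
    device,
    compute_type="float32",
    asr_options={
        "max_new_tokens": 100,   # illustrative values, adjust as needed
        "hallucination_silence_threshold": 0.5,
        "hotwords": None,
    },
)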