jcho02 committed on
Commit
2c19de2
·
verified ·
1 Parent(s): 6d9e5fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -10
app.py CHANGED
@@ -1,5 +1,3 @@
1
-
2
-
3
  # Import necessary libraries
4
  import gradio as gr
5
  import torch
@@ -10,6 +8,9 @@ import datasets
10
  from datasets import load_dataset, DatasetDict, Audio
11
  from huggingface_hub import PyTorchModelHubMixin
12
 
 
 
 
13
  # Define data class
14
  class SpeechInferenceDataset(Dataset):
15
  def __init__(self, audio_data, text_processor):
@@ -57,7 +58,6 @@ def prepare_data(audio_file_path, model_checkpoint="openai/whisper-base"):
57
  inference_dataset = SpeechInferenceDataset(inference_data, feature_extractor)
58
  inference_loader = DataLoader(inference_dataset, batch_size=1, shuffle=False)
59
  input_features, decoder_input_ids = next(iter(inference_loader))
60
- # Replace 'device' with your device configuration (e.g., 'cuda' or 'cpu')
61
  input_features = input_features.squeeze(1).to(device)
62
  decoder_input_ids = decoder_input_ids.squeeze(1).to(device)
63
  return input_features, decoder_input_ids
@@ -68,6 +68,8 @@ def predict(audio_file_path, config={"encoder": "openai/whisper-base", "num_labe
68
 
69
  # Load the model from Hugging Face Hub
70
  model = SpeechClassifier(config)
 
 
71
  model.load_state_dict(torch.load(model.push_from_hub("jcho02/whisper_cleft")))
72
  model.eval()
73
 
@@ -76,16 +78,38 @@ def predict(audio_file_path, config={"encoder": "openai/whisper-base", "num_labe
76
  predicted_ids = int(torch.argmax(logits, dim=-1))
77
  return predicted_ids
78
 
79
- # Gradio Interface function
80
- def gradio_interface(uploaded_file):
81
  with open(uploaded_file.name, "wb") as f:
82
  f.write(uploaded_file.read())
83
  prediction = predict(uploaded_file.name)
84
  label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
85
  return label
86
 
87
- # Create and launch Gradio Interface with File upload input
88
- iface = gr.Interface(fn=gradio_interface,
89
- inputs=gr.inputs.File(label="Upload Audio File"),
90
- outputs="text")
91
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Import necessary libraries
2
  import gradio as gr
3
  import torch
 
8
  from datasets import load_dataset, DatasetDict, Audio
9
  from huggingface_hub import PyTorchModelHubMixin
10
 
11
+ # Ensure you have the device setup (cuda or cpu)
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
  # Define data class
15
  class SpeechInferenceDataset(Dataset):
16
  def __init__(self, audio_data, text_processor):
 
58
  inference_dataset = SpeechInferenceDataset(inference_data, feature_extractor)
59
  inference_loader = DataLoader(inference_dataset, batch_size=1, shuffle=False)
60
  input_features, decoder_input_ids = next(iter(inference_loader))
 
61
  input_features = input_features.squeeze(1).to(device)
62
  decoder_input_ids = decoder_input_ids.squeeze(1).to(device)
63
  return input_features, decoder_input_ids
 
68
 
69
  # Load the model from Hugging Face Hub
70
  model = SpeechClassifier(config)
71
+ model.to(device)
72
+ # Use the correct method to load your model (this is an example and may not directly apply)
73
  model.load_state_dict(torch.load(model.push_from_hub("jcho02/whisper_cleft")))
74
  model.eval()
75
 
 
78
  predicted_ids = int(torch.argmax(logits, dim=-1))
79
  return predicted_ids
80
 
81
+ # Gradio Interface function for uploaded files
82
+ def gradio_file_interface(uploaded_file):
83
  with open(uploaded_file.name, "wb") as f:
84
  f.write(uploaded_file.read())
85
  prediction = predict(uploaded_file.name)
86
  label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
87
  return label
88
 
89
+ # Gradio Interface function for microphone input
90
+ def gradio_mic_interface(mic_input):
91
+ prediction = predict(mic_input.name)
92
+ label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
93
+ return label
94
+
95
+ # Initialize Blocks
96
+ demo = gr.Blocks()
97
+
98
+ # Define the interfaces inside the Blocks context
99
+ with demo:
100
+ mic_transcribe = gr.Interface(
101
+ fn=gradio_mic_interface,
102
+ inputs=gr.Audio(source="microphone", type="filepath"),
103
+ outputs=gr.Textbox(label="Prediction")
104
+ )
105
+
106
+ file_transcribe = gr.Interface(
107
+ fn=gradio_file_interface,
108
+ inputs=gr.Audio(source="upload", type="filepath"),
109
+ outputs=gr.Textbox(label="Prediction")
110
+ )
111
+
112
+ gr.TabbedInterface([mic_transcribe, file_transcribe], ["Transcribe Microphone", "Transcribe Audio File"])
113
+
114
+ # Launch the demo
115
+ demo.launch(debug=True)