Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
|
2 |
-
|
3 |
# Import necessary libraries
|
4 |
import gradio as gr
|
5 |
import torch
|
@@ -10,6 +8,9 @@ import datasets
|
|
10 |
from datasets import load_dataset, DatasetDict, Audio
|
11 |
from huggingface_hub import PyTorchModelHubMixin
|
12 |
|
|
|
|
|
|
|
13 |
# Define data class
|
14 |
class SpeechInferenceDataset(Dataset):
|
15 |
def __init__(self, audio_data, text_processor):
|
@@ -57,7 +58,6 @@ def prepare_data(audio_file_path, model_checkpoint="openai/whisper-base"):
|
|
57 |
inference_dataset = SpeechInferenceDataset(inference_data, feature_extractor)
|
58 |
inference_loader = DataLoader(inference_dataset, batch_size=1, shuffle=False)
|
59 |
input_features, decoder_input_ids = next(iter(inference_loader))
|
60 |
-
# Replace 'device' with your device configuration (e.g., 'cuda' or 'cpu')
|
61 |
input_features = input_features.squeeze(1).to(device)
|
62 |
decoder_input_ids = decoder_input_ids.squeeze(1).to(device)
|
63 |
return input_features, decoder_input_ids
|
@@ -68,6 +68,8 @@ def predict(audio_file_path, config={"encoder": "openai/whisper-base", "num_labe
|
|
68 |
|
69 |
# Load the model from Hugging Face Hub
|
70 |
model = SpeechClassifier(config)
|
|
|
|
|
71 |
model.load_state_dict(torch.load(model.push_from_hub("jcho02/whisper_cleft")))
|
72 |
model.eval()
|
73 |
|
@@ -76,16 +78,38 @@ def predict(audio_file_path, config={"encoder": "openai/whisper-base", "num_labe
|
|
76 |
predicted_ids = int(torch.argmax(logits, dim=-1))
|
77 |
return predicted_ids
|
78 |
|
79 |
-
# Gradio Interface function
|
80 |
-
def
|
81 |
with open(uploaded_file.name, "wb") as f:
|
82 |
f.write(uploaded_file.read())
|
83 |
prediction = predict(uploaded_file.name)
|
84 |
label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
|
85 |
return label
|
86 |
|
87 |
-
#
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Import necessary libraries
|
2 |
import gradio as gr
|
3 |
import torch
|
|
|
8 |
from datasets import load_dataset, DatasetDict, Audio
|
9 |
from huggingface_hub import PyTorchModelHubMixin
|
10 |
|
11 |
+
# Select the compute device: CUDA when available, otherwise CPU
|
12 |
+
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    # No usable GPU: run inference on the CPU.
    device = torch.device("cpu")
|
13 |
+
|
14 |
# Define data class
|
15 |
class SpeechInferenceDataset(Dataset):
|
16 |
def __init__(self, audio_data, text_processor):
|
|
|
58 |
inference_dataset = SpeechInferenceDataset(inference_data, feature_extractor)
|
59 |
inference_loader = DataLoader(inference_dataset, batch_size=1, shuffle=False)
|
60 |
input_features, decoder_input_ids = next(iter(inference_loader))
|
|
|
61 |
input_features = input_features.squeeze(1).to(device)
|
62 |
decoder_input_ids = decoder_input_ids.squeeze(1).to(device)
|
63 |
return input_features, decoder_input_ids
|
|
|
68 |
|
69 |
# Load the model from Hugging Face Hub
|
70 |
model = SpeechClassifier(config)
|
71 |
+
model.to(device)
|
72 |
+
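# NOTE(review): `push_from_hub` does not appear to be a huggingface_hub API (the mixin offers `from_pretrained` / `push_to_hub`); confirm, and fetch the checkpoint with `hf_hub_download` or load via `from_pretrained` instead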
|
73 |
model.load_state_dict(torch.load(model.push_from_hub("jcho02/whisper_cleft")))
|
74 |
model.eval()
|
75 |
|
|
|
78 |
predicted_ids = int(torch.argmax(logits, dim=-1))
|
79 |
return predicted_ids
|
80 |
|
81 |
+
# Gradio Interface function for uploaded files
|
82 |
+
def gradio_file_interface(uploaded_file):
    """Classify an uploaded audio file and return a human-readable label.

    Args:
        uploaded_file: The value Gradio passes to the callback. With
            ``gr.Audio(type="filepath")`` this is a ``str`` path to the
            audio file; a file-like object exposing ``.name`` is accepted
            as well. ``None`` means no file was provided.

    Returns:
        ``"Hypernasality Detected"`` when ``predict`` returns 1,
        ``"No Hypernasality Detected"`` for any other prediction, or
        ``"No audio provided."`` when there is no input.
    """
    if uploaded_file is None:
        return "No audio provided."

    # Accept both plain paths and file wrappers. The audio is already on
    # disk, so it must not be reopened with "wb": that truncates the file
    # before it can be read.
    audio_path = getattr(uploaded_file, "name", uploaded_file)

    prediction = predict(audio_path)
    if prediction == 1:
        return "Hypernasality Detected"
    return "No Hypernasality Detected"
|
88 |
|
89 |
+
# Gradio Interface function for microphone input
|
90 |
+
def gradio_mic_interface(mic_input):
    """Classify a microphone recording and return a human-readable label.

    Args:
        mic_input: The value Gradio passes to the callback. With
            ``gr.Audio(type="filepath")`` this is a ``str`` path to the
            recorded audio; an object exposing ``.name`` is accepted as
            well. ``None`` means nothing was recorded.

    Returns:
        ``"Hypernasality Detected"`` when ``predict`` returns 1,
        ``"No Hypernasality Detected"`` for any other prediction, or
        ``"No audio provided."`` when there is no input.
    """
    if mic_input is None:
        return "No audio provided."

    # A filepath input is a plain string with no `.name` attribute, so fall
    # back to the value itself when the attribute is missing.
    audio_path = getattr(mic_input, "name", mic_input)

    prediction = predict(audio_path)
    if prediction == 1:
        return "Hypernasality Detected"
    return "No Hypernasality Detected"
|
94 |
+
|
95 |
+
# Build the demo: one interface per audio source, combined into two tabs.
with gr.Blocks() as demo:
    # Microphone recordings go to gradio_mic_interface.
    mic_transcribe = gr.Interface(
        fn=gradio_mic_interface,
        inputs=gr.Audio(source="microphone", type="filepath"),
        outputs=gr.Textbox(label="Prediction"),
    )

    # Uploaded audio files go to gradio_file_interface.
    file_transcribe = gr.Interface(
        fn=gradio_file_interface,
        inputs=gr.Audio(source="upload", type="filepath"),
        outputs=gr.Textbox(label="Prediction"),
    )

    tab_interfaces = [mic_transcribe, file_transcribe]
    tab_titles = ["Transcribe Microphone", "Transcribe Audio File"]
    gr.TabbedInterface(tab_interfaces, tab_titles)

# Start the app with debug output enabled.
demo.launch(debug=True)
|