AshDavid12 committed
Commit 1c789c0 · 1 Parent(s): 1ab0cdf
added validation for wav and pcm
Changed files:
- .gitignore (+1 -0)
- client.py (+78 -23)
- infer.py (+51 -6)
- poetry.lock (+22 -1)
- pyproject.toml (+2 -0)
.gitignore ADDED
@@ -0,0 +1 @@
+*.wav
client.py CHANGED
@@ -1,5 +1,6 @@
 import asyncio
 import json
+import logging
 import wave

 import websockets
@@ -9,8 +10,62 @@ import ssl
 # Parameters for reading and sending the audio
 AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav"  # Use WAV file

+from pydub import AudioSegment
+
+
+# Convert and resample audio before writing it to WAV
+# Convert and resample audio before writing it to WAV
+def convert_to_mono_16k(audio_file_path):
+    logging.info(f"Starting audio conversion to mono and resampling to 16kHz for file: {audio_file_path}")
+
+    try:
+        # Load the audio file into an AudioSegment object
+        audio_segment = AudioSegment.from_file(audio_file_path, format="wav")
+
+        # Convert the audio to mono and resample it to 16kHz
+        audio_segment = audio_segment.set_channels(1).set_frame_rate(16000)
+
+        logging.info("Audio conversion to mono and 16kHz completed successfully.")
+    except Exception as e:
+        logging.error(f"Error during audio conversion: {e}")
+        raise e
+
+    # Return the modified AudioSegment object
+    return audio_segment
+
+
 async def send_audio(websocket):
     buffer_size = 1024 * 16  # Send smaller chunks (16KB) for real-time processing
+    logging.info("Converting the audio to mono and 16kHz.")
+
+    try:
+        converted_audio = convert_to_mono_16k('test_copy.wav')
+    except Exception as e:
+        logging.error(f"Failed to convert audio: {e}")
+        return
+
+    # Send metadata to the server
+    metadata = {
+        'sample_rate': 16000,  # Resampled rate
+        'channels': 1,         # Converted to mono
+        'sampwidth': 2         # Assuming 16-bit audio
+    }
+    await websocket.send(json.dumps(metadata))
+    logging.info(f"Sent metadata: {metadata}")
+
+    try:
+        raw_data = converted_audio.raw_data
+        logging.info(f"Starting to send raw PCM audio data. Total data size: {len(raw_data)} bytes.")
+
+        for i in range(0, len(raw_data), buffer_size):
+            pcm_chunk = raw_data[i:i + buffer_size]
+            await websocket.send(pcm_chunk)  # Send raw PCM data chunk
+            logging.info(f"Sent PCM chunk of size {len(pcm_chunk)} bytes.")
+            await asyncio.sleep(0.01)  # Simulate real-time sending
+
+        logging.info("Completed sending all audio data.")
+    except Exception as e:
+        logging.error(f"Error while sending audio data: {e}")

     # Download the WAV file locally
     # with requests.get(AUDIO_FILE_URL, stream=True) as response:
@@ -21,29 +76,29 @@
     # print("Audio file downloaded successfully.")

     # Open the downloaded WAV file and extract PCM data
- [23 lines removed — the previous, un-commented wave.open()-based metadata/PCM sending block; its content is not shown in this diff view]
+    # with wave.open('test_copy.wav', 'rb') as wav_file:
+    #     metadata = {
+    #         'sample_rate': wav_file.getframerate(),
+    #         'channels': wav_file.getnchannels(),
+    #         'sampwidth': wav_file.getsampwidth(),
+    #     }
+    #
+    #     # Send metadata to the server before sending the audio
+    #     await websocket.send(json.dumps(metadata))
+    #     print(f"Sent metadata: {metadata}")
+
+    #     # Send the PCM audio data in chunks
+    #     while True:
+    #         pcm_chunk = wav_file.readframes(buffer_size)
+    #         if not pcm_chunk:
+    #             break  # End of file
+    #
+    #         await websocket.send(pcm_chunk)  # Send raw PCM data chunk
+    #         # print(f"Sent PCM chunk of size {len(pcm_chunk)} bytes.")
+    #         await asyncio.sleep(0.01)  # Simulate real-time sending
+
+    # else:
+    #     print(f"Failed to download audio file. Status code: {response.status_code}")


 async def receive_transcription(websocket):
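For context, a minimal sketch of how these two coroutines are typically driven from a client entrypoint. The commit does not show this part; the server host and the use of asyncio.gather below are assumptions, not taken from the repository (only the /wtranscribe path appears in infer.py):

    import asyncio
    import websockets

    SERVER_URI = "wss://example-server/wtranscribe"  # hypothetical server URL

    async def main():
        # Open one WebSocket connection and run sender and receiver concurrently
        async with websockets.connect(SERVER_URI) as websocket:
            await asyncio.gather(
                send_audio(websocket),             # streams metadata + raw PCM chunks
                receive_transcription(websocket),  # handles transcription results as they arrive
            )

    if __name__ == "__main__":
        asyncio.run(main())
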
infer.py CHANGED
@@ -131,9 +131,6 @@ def transcribe_core_ws(audio_file, last_transcribed_time):
     """
     Transcribe the audio file and return only the segments that have not been processed yet.

-    :param audio_file: Path to the growing audio file.
-    :param last_transcribed_time: The last time (in seconds) that was transcribed.
-    :return: Newly transcribed segments and the updated last transcribed time.
     """
     logging.info(f"Starting transcription for file: {audio_file} from {last_transcribed_time} seconds.")

@@ -177,6 +174,43 @@ def transcribe_core_ws(audio_file, last_transcribed_time):
 import tempfile


+# Function to verify if the PCM data is valid
+def validate_pcm_data(pcm_audio_buffer, sample_rate, channels, sample_width):
+    """Validates the PCM data buffer to ensure it conforms to the expected format."""
+    logging.info(f"Validating PCM data: total size = {len(pcm_audio_buffer)} bytes.")
+
+    # Calculate the expected sample size
+    expected_sample_size = sample_rate * channels * sample_width
+    actual_sample_size = len(pcm_audio_buffer)
+
+    if actual_sample_size == 0:
+        logging.error("Received PCM data is empty.")
+        return False
+
+    logging.info(f"Expected sample size per second: {expected_sample_size} bytes.")
+
+    if actual_sample_size % expected_sample_size != 0:
+        logging.warning(
+            f"PCM data size {actual_sample_size} is not a multiple of the expected sample size per second ({expected_sample_size} bytes). Data may be corrupted or incomplete.")
+
+    return True
+
+
+# Function to validate if the created WAV file is valid
+def validate_wav_file(wav_file_path):
+    """Validates if the WAV file was created correctly and can be opened."""
+    try:
+        with wave.open(wav_file_path, 'rb') as wav_file:
+            sample_rate = wav_file.getframerate()
+            channels = wav_file.getnchannels()
+            sample_width = wav_file.getsampwidth()
+            logging.info(
+                f"WAV file details - Sample Rate: {sample_rate}, Channels: {channels}, Sample Width: {sample_width}")
+        return True
+    except wave.Error as e:
+        logging.error(f"Error reading WAV file: {e}")
+        return False
+
 @app.websocket("/wtranscribe")
 async def websocket_transcribe(websocket: WebSocket):
     logging.info("New WebSocket connection request received.")
@@ -214,6 +248,12 @@ async def websocket_transcribe(websocket: WebSocket):
             # Accumulate the raw PCM data into the buffer
             pcm_audio_buffer.extend(audio_chunk)

+            # Validate the PCM data after each chunk
+            if not validate_pcm_data(pcm_audio_buffer, sample_rate, channels, sample_width):
+                logging.error("Invalid PCM data received. Aborting transcription.")
+                await websocket.send_json({"error": "Invalid PCM data received."})
+                return
+
             # Estimate the duration of the chunk based on its size
             chunk_duration = len(audio_chunk) / (sample_rate * channels * sample_width)
             accumulated_audio_time += chunk_duration
@@ -233,6 +273,11 @@ async def websocket_transcribe(websocket: WebSocket):
                 wav_file.setframerate(sample_rate)
                 wav_file.writeframes(pcm_audio_buffer)

+            if not validate_wav_file(temp_wav_file.name):
+                logging.error(f"Invalid WAV file created: {temp_wav_file.name}")
+                await websocket.send_json({"error": "Invalid WAV file created."})
+                return
+
             logging.info(f"Temporary WAV file created at {temp_wav_file.name} for transcription.")

             # Log to confirm that the file exists and has the expected size
@@ -260,9 +305,9 @@ async def websocket_transcribe(websocket: WebSocket):
             await websocket.send_json(response)

             # Optionally delete the temporary WAV file after processing
-            if os.path.exists(temp_wav_file):
-                os.remove(temp_wav_file)
-                logging.info(f"Temporary WAV file {temp_wav_file} removed.")
+            if os.path.exists(temp_wav_file.name):
+                os.remove(temp_wav_file.name)
+                logging.info(f"Temporary WAV file {temp_wav_file.name} removed.")

     except WebSocketDisconnect:
         logging.info("WebSocket connection closed by the client.")
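A rough standalone check of the two validators added above (a sketch for illustration only, not part of the commit; the silence buffer and the check.wav file name are made up):

    import logging
    import wave

    logging.basicConfig(level=logging.INFO)

    sample_rate, channels, sample_width = 16000, 1, 2
    pcm_audio_buffer = bytearray(b"\x00\x00" * sample_rate)  # one second of 16-bit mono silence

    assert validate_pcm_data(pcm_audio_buffer, sample_rate, channels, sample_width)

    # Write the buffer to a WAV file the same way the endpoint does, then validate it
    with wave.open("check.wav", "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(sample_width)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm_audio_buffer)

    assert validate_wav_file("check.wav")

Note that validate_pcm_data only logs a warning (and still returns True) when the buffer is not a whole-second multiple, so partially filled chunks are not rejected; only an empty buffer fails validation.
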
poetry.lock CHANGED
@@ -1064,6 +1064,16 @@ tokenizers = ">=0.13,<1"
 conversion = ["transformers[torch] (>=4.23)"]
 dev = ["black (==23.*)", "flake8 (==6.*)", "isort (==5.*)", "pytest (==7.*)"]

+[[package]]
+name = "ffmpeg"
+version = "1.4"
+description = "ffmpeg python package url [https://github.com/jiashaokun/ffmpeg]"
+optional = false
+python-versions = "*"
+files = [
+    {file = "ffmpeg-1.4.tar.gz", hash = "sha256:6931692c890ff21d39938433c2189747815dca0c60ddc7f9bb97f199dba0b5b9"},
+]
+
 [[package]]
 name = "filelock"
 version = "3.16.0"
@@ -2539,6 +2549,17 @@ azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0
 toml = ["tomli (>=2.0.1)"]
 yaml = ["pyyaml (>=6.0.1)"]

+[[package]]
+name = "pydub"
+version = "0.25.1"
+description = "Manipulate audio with an simple and easy high level interface"
+optional = false
+python-versions = "*"
+files = [
+    {file = "pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6"},
+    {file = "pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f"},
+]
+
 [[package]]
 name = "pygments"
 version = "2.18.0"
@@ -3862,4 +3883,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "3.9.1"
-content-hash = "
+content-hash = "62e30245d9470305f2b33ff86655c5a38e9f58c708b7ffb3cdfbf932ccfda6c7"
pyproject.toml CHANGED
@@ -24,6 +24,8 @@ openai = "^1.42.0"
 numpy = "^1.22.0"
 torch = "2.1.0"
 sounddevice = "^0.5.0"
+pydub = "^0.25.1"
+ffmpeg = "^1.4"


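One packaging note, hedged because it is not stated anywhere in the commit: the ffmpeg distribution on PyPI (the "ffmpeg = ^1.4" entry above) is a Python package and does not install the ffmpeg binary itself, while pydub shells out to a system ffmpeg/avconv executable for anything it cannot read natively (plain PCM WAV files usually decode without it). A quick way to see what pydub will actually use at runtime:

    from pydub.utils import which

    # Prints the resolved path of the ffmpeg executable, or None if pydub
    # can only fall back to its native WAV reader.
    print(which("ffmpeg"))
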