muhtasham committed
Commit a1d6c0c · 1 Parent(s): dbe4a4a

Files changed (2):
  1. app.py +128 -19
  2. requirements.txt +1 -1
app.py CHANGED

@@ -1,19 +1,30 @@
- import spaces
  import torch
  import gradio as gr
- from transformers import pipeline
  import subprocess
- from loguru import logger
  import datetime
  import tempfile
- import os
- import json
- from pathlib import Path
+ from transformers import pipeline
+ from loguru import logger

  MODEL_NAME = "muhtasham/whisper-tg"

  def format_time(seconds):
-     """Convert seconds to SRT time format (HH:MM:SS,mmm)"""
+     """Convert seconds to SRT time format (HH:MM:SS,mmm).
+
+     Args:
+         seconds (float): Time in seconds to convert.
+
+     Returns:
+         str: Time formatted as HH:MM:SS,mmm where:
+             - HH: Hours (00-99)
+             - MM: Minutes (00-59)
+             - SS: Seconds (00-59)
+             - mmm: Milliseconds (000-999)
+
+     Example:
+         >>> format_time(3661.5)
+         '01:01:01,500'
+     """
      td = datetime.timedelta(seconds=float(seconds))
      hours = td.seconds // 3600
      minutes = (td.seconds % 3600) // 60
@@ -22,7 +33,35 @@ def format_time(seconds):
      return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"

  def generate_srt(chunks):
-     """Generate SRT format subtitles from chunks"""
+     """Generate SRT format subtitles from transcription chunks.
+
+     Args:
+         chunks (list): List of dictionaries containing transcription chunks.
+             Each chunk must have:
+             - "timestamp": List of [start_time, end_time] in seconds
+             - "text": The transcribed text for that time segment
+
+     Returns:
+         str: SRT formatted subtitles string with format:
+         ```
+         1
+         HH:MM:SS,mmm --> HH:MM:SS,mmm
+         Text content
+
+         2
+         HH:MM:SS,mmm --> HH:MM:SS,mmm
+         Text content
+         ...
+         ```
+
+     Example:
+         >>> chunks = [
+         ...     {"timestamp": [0.0, 1.5], "text": "Hello"},
+         ...     {"timestamp": [1.5, 3.0], "text": "World"}
+         ... ]
+         >>> generate_srt(chunks)
+         '1\\n00:00:00,000 --> 00:00:01,500\\nHello\\n\\n2\\n00:00:01,500 --> 00:00:03,000\\nWorld\\n\\n'
+     """
      srt_content = []
      for i, chunk in enumerate(chunks, 1):
          start_time = format_time(chunk["timestamp"][0])
@@ -32,7 +71,20 @@ def generate_srt(chunks):
      return "".join(srt_content)

  def save_srt_to_file(srt_content):
-     """Save SRT content to a temporary file and return the file path"""
+     """Save SRT content to a temporary file.
+
+     Args:
+         srt_content (str): The SRT formatted subtitles content to save.
+
+     Returns:
+         str or None: Path to the temporary file if content was saved,
+             None if srt_content was empty.
+
+     Note:
+         The temporary file is created with delete=False to allow it to be
+         used after the function returns. The file should be deleted by the
+         caller when no longer needed.
+     """
      if not srt_content:
          return None

@@ -54,33 +106,81 @@ def check_ffmpeg():
  # Initialize ffmpeg check
  check_ffmpeg()

- device = 0 if torch.cuda.is_available() else "cpu"
+ # Use T4 GPU if available, otherwise fallback to CPU
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
  logger.info(f"Using device: {device}")

- def create_pipeline(chunk_length_s):
-     """Create a new pipeline with specified chunk length"""
+ def create_pipeline():
+     """Create a new pipeline with optimized settings for T4 GPU.
+
+     Returns:
+         transformers.Pipeline: Configured speech recognition pipeline.
+     """
      return pipeline(
          task="automatic-speech-recognition",
          model=MODEL_NAME,
-         chunk_length_s=chunk_length_s,
          device=device,
+         torch_dtype=torch.float16,  # Use float16 for better performance on T4
+         framework="pt",  # Explicitly use PyTorch
+         return_timestamps=True,  # Always return timestamps for better control
+         generate_kwargs={
+             "task": "transcribe",  # Explicitly set transcription task
+             "language": "tg",  # Default to Tajik
+             "condition_on_previous_text": True,  # Use context from previous chunks
+             "compression_ratio_threshold": 1.2,  # Filter out low-quality transcriptions
+             "temperature": 0.0,  # Use greedy decoding for faster inference
+             "no_speech_threshold": 0.6,  # Threshold for detecting speech
+             "logprob_threshold": -1.0,  # Threshold for log probability
+             "best_of": 1,  # Use single best path for faster inference
+         }
      )

- # Initialize default pipeline
- pipe = create_pipeline(30)
+ # Initialize pipeline once
+ pipe = create_pipeline()
  logger.info(f"Pipeline initialized: {pipe}")

- @spaces.GPU
  def transcribe(inputs, return_timestamps, generate_subs, batch_size, chunk_length_s):
+     """Transcribe audio input using Whisper model.
+
+     Args:
+         inputs (str): Path to audio file to transcribe.
+         return_timestamps (bool): Whether to include timestamps in output.
+         generate_subs (bool): Whether to generate SRT subtitles.
+         batch_size (int): Number of chunks to process in parallel.
+         chunk_length_s (int): Length of audio chunks in seconds.
+
+     Returns:
+         tuple: (formatted_result, srt_file, correction_text)
+             - formatted_result (dict): Transcription results
+             - srt_file (str): Path to SRT file if generated, None otherwise
+             - correction_text (str): Empty string for corrections
+
+     Raises:
+         gr.Error: If no audio file is provided or transcription fails.
+     """
      if inputs is None:
          logger.warning("No audio file submitted")
          raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

      try:
          logger.info(f"Processing audio file: {inputs}")
-         # Create new pipeline with specified chunk length
-         current_pipe = create_pipeline(chunk_length_s)
-         result = current_pipe(inputs, batch_size=batch_size, return_timestamps=return_timestamps)
+
+         # Calculate optimal chunk and stride lengths based on input
+         stride_length_s = chunk_length_s / 6  # Default stride for better context
+
+         # Clear CUDA cache before processing
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             logger.debug("Cleared CUDA cache before processing")
+
+         # Process audio with dynamic chunking
+         result = pipe(
+             inputs,
+             batch_size=batch_size,
+             chunk_length_s=chunk_length_s,
+             stride_length_s=stride_length_s,
+             return_timestamps=return_timestamps
+         )
          logger.debug(f"Pipeline result: {result}")

          # Format response as JSON
@@ -121,8 +221,17 @@ def transcribe(inputs, return_timestamps, generate_subs, batch_size, chunk_lengt
              srt_file = save_srt_to_file(srt_content)
              logger.info("SRT subtitles generated successfully")

+         # Clear CUDA cache after processing
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             logger.debug("Cleared CUDA cache after processing")
+
          return formatted_result, srt_file, ""  # Return empty string for correction textbox
      except Exception as e:
+         # Ensure CUDA cache is cleared even if there's an error
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             logger.debug("Cleared CUDA cache after error")
          logger.exception(f"Error during transcription: {str(e)}")
          raise gr.Error(f"Failed to transcribe audio: {str(e)}")
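For reference, the behavioral core of this commit is the single shared pipeline plus chunking parameters passed at call time. Below is a minimal standalone sketch of that call path, not part of the commit: `sample.wav` and the batch size are hypothetical, and the float32 CPU fallback is an assumption (the commit itself pins float16 for the T4).

# Minimal sketch, not part of the commit. Assumes torch and transformers are
# installed; "sample.wav" is a hypothetical local audio file.
import torch
from transformers import pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"
asr = pipeline(
    task="automatic-speech-recognition",
    model="muhtasham/whisper-tg",
    device=device,
    # The commit uses float16 for the T4; float32 on CPU is an assumption here.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)

chunk_length_s = 30  # hypothetical UI value
result = asr(
    "sample.wav",
    batch_size=8,  # hypothetical UI value
    chunk_length_s=chunk_length_s,
    stride_length_s=chunk_length_s / 6,  # the stride rule this commit introduces
    return_timestamps=True,
)
print(result["text"])    # full transcript
print(result["chunks"])  # [{"timestamp": (start, end), "text": ...}, ...]

With `return_timestamps=True`, the `chunks` entries have exactly the shape `generate_srt` expects, so `generate_srt(result["chunks"])` would yield the SRT payload that `save_srt_to_file` writes out.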
 
requirements.txt CHANGED

@@ -1,2 +1,2 @@
  transformers
- loguru
+ loguru