reab5555 committed on
Commit
2821cda
·
verified ·
1 Parent(s): 34eebab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -111
app.py CHANGED
@@ -6,145 +6,230 @@ import torch
6
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
7
  from moviepy.editor import VideoFileClip
8
 
9
- def transcribe(video_file, transcribe_to_text, transcribe_to_srt, language):
10
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
11
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
12
- model_id = "openai/whisper-large-v3"
13
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
14
- model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
15
- )
16
- model.to(device)
17
- processor = AutoProcessor.from_pretrained(model_id)
18
- pipe = pipeline(
19
- "automatic-speech-recognition",
20
- model=model,
21
- tokenizer=processor.tokenizer,
22
- feature_extractor=processor.feature_extractor,
23
- max_new_tokens=128,
24
- chunk_length_s=10,
25
- batch_size=2,
26
- return_timestamps=True,
27
- torch_dtype=torch_dtype,
28
- device=device,
29
- )
30
-
31
- if video_file is None:
32
- yield "Error: No video file provided.", None
33
- return
34
-
35
- video_path = video_file.name if hasattr(video_file, 'name') else video_file
36
- try:
37
- video = VideoFileClip(video_path)
38
- except Exception as e:
39
- yield f"Error processing video file: {str(e)}", None
40
- return
41
-
42
- audio = video.audio
43
- duration = video.duration
44
- n_chunks = math.ceil(duration / 10)
45
- transcription_txt = ""
46
- transcription_srt = []
47
-
48
- for i in range(n_chunks):
49
- start = i * 10
50
- end = min((i + 1) * 10, duration)
51
- audio_chunk = audio.subclip(start, end)
52
-
53
- temp_file_path = f"temp_audio_{i}.wav"
54
- audio_chunk.write_audiofile(temp_file_path, codec='pcm_s16le')
55
-
56
- with open(temp_file_path, "rb") as temp_file:
57
- result = pipe(temp_file_path, generate_kwargs={"language": language})
58
- transcription_txt += result["text"]
59
- if transcribe_to_srt:
60
- for chunk in result["chunks"]:
61
- start_time, end_time = chunk["timestamp"]
62
- if start_time is not None and end_time is not None:
63
- transcription_srt.append({
64
- "start": start_time + i * 10,
65
- "end": end_time + i * 10,
66
- "text": chunk["text"]
67
- })
68
- else:
69
- print(f"Warning: Invalid timestamp for chunk: {chunk}")
70
-
71
- os.remove(temp_file_path)
72
- yield f"Progress: {int(((i + 1) / n_chunks) * 100)}%", None
73
-
74
- output = ""
75
- srt_file_path = None
76
- if transcribe_to_text:
77
- output += "Text Transcription:\n" + transcription_txt + "\n\n"
78
- if transcribe_to_srt:
79
- output += "SRT Transcription:\n"
80
- srt_content = ""
81
- for i, sub in enumerate(transcription_srt, 1):
82
- srt_entry = f"{i}\n{format_time(sub['start'])} --> {format_time(sub['end'])}\n{sub['text']}\n\n"
83
- srt_content += srt_entry
84
-
85
- # Remove duplicate captions and keep only the last occurrence
86
- cleaned_srt_content = clean_srt_duplicates(srt_content)
87
-
88
- # Save SRT content to a file
89
- srt_file_path = "transcription.srt"
90
- with open(srt_file_path, "w", encoding="utf-8") as srt_file:
91
- srt_file.write(cleaned_srt_content)
92
-
93
- output += f"\nSRT file saved as: {srt_file_path}"
94
 
95
- yield output, srt_file_path
 
 
 
 
 
 
96
 
97
  def format_time(seconds):
 
98
  m, s = divmod(seconds, 60)
99
  h, m = divmod(m, 60)
100
  return f"{int(h):02d}:{int(m):02d}:{s:06.3f}".replace('.', ',')
101
 
102
- def clean_srt_duplicates(srt_content, time_threshold=30):
103
  """
104
- Function to remove duplicate captions within a specified time range in SRT format,
105
  keeping only the last occurrence.
106
  """
107
- cleaned_srt = []
108
- last_seen = {}
109
-
110
- # Pattern to match each SRT block
111
- srt_pattern = re.compile(r"(\d+)\n(\d{2}:\d{2}:\d{2},\d{3}) --> (\d{2}:\d{2}:\d{2},\d{3})\n(.+)", re.DOTALL)
112
 
113
- # Store blocks temporarily to determine duplicates
114
  blocks = []
 
 
115
  for match in srt_pattern.finditer(srt_content):
116
  index, start_time, end_time, text = match.groups()
117
  text = text.strip()
118
 
119
- # Convert start time to seconds
120
- start_seconds = sum(int(x) * 60 ** i for i, x in enumerate(reversed(start_time.split(":"))))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- # Only keep the last instance within the time threshold
123
- if text in last_seen and (start_seconds - last_seen[text]) < time_threshold:
124
- blocks.pop() # Remove the previous occurrence
125
- blocks.append((index, start_time, end_time, text))
126
- last_seen[text] = start_seconds # Update last occurrence time
127
 
128
- # Build cleaned SRT content
129
- for i, (index, start_time, end_time, text) in enumerate(blocks, 1):
130
- cleaned_srt.append(f"{i}\n{start_time} --> {end_time}\n{text}\n\n")
 
 
 
 
 
131
 
132
- return ''.join(cleaned_srt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
 
134
  iface = gr.Interface(
135
  fn=transcribe,
136
  inputs=[
137
- gr.Video(),
138
- gr.Checkbox(label="Transcribe to Text"),
139
- gr.Checkbox(label="Transcribe to SRT"),
140
- gr.Dropdown(choices=['en', 'he', 'it', 'es', 'fr', 'de', 'zh', 'ar'], label="Language")
 
 
 
 
141
  ],
142
  outputs=[
143
  gr.Textbox(label="Transcription Output"),
144
  gr.File(label="Download SRT")
145
  ],
146
  title="WhisperCap Video Transcription",
147
- description="Upload a video file to transcribe its audio using Whisper. You can download the SRT file if generated.",
 
 
 
 
 
 
148
  )
149
 
150
- iface.launch(share=True)
 
 
 
6
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
7
  from moviepy.editor import VideoFileClip
8
 
9
def timestamp_to_seconds(timestamp):
    """Convert an SRT-style timestamp to a number of seconds.

    Args:
        timestamp: Text of the form ``HH:MM:SS,mmm``. A ``.`` is accepted in
            place of the comma, and the fractional part may be absent or have
            any number of digits.

    Returns:
        The position in seconds as a float.

    Raises:
        ValueError: If the text does not have three colon-separated fields or
            a field is not an integer.
    """
    hours, minutes, rest = timestamp.strip().split(":")
    # Treat "." like "," so both decimal separators are handled by one split.
    seconds, _, fraction = rest.replace(".", ",").partition(",")

    whole_seconds = int(hours) * 3600 + int(minutes) * 60 + int(seconds)
    # Scale by the number of digits actually present: ",5" means half a
    # second, not five milliseconds.
    fractional = int(fraction) / 10 ** len(fraction) if fraction else 0.0
    return whole_seconds + fractional
23
 
24
def format_time(seconds):
    """Convert a position in seconds to an SRT timestamp (``HH:MM:SS,mmm``).

    The value is rounded to whole milliseconds before it is split into
    fields, so a value such as 59.9996 carries into the minutes field
    ("00:01:00,000") instead of producing an invalid "60,000" seconds field.

    Args:
        seconds: Position in seconds (int or float). Negative values are
            clamped to zero because SRT cannot express them.

    Returns:
        The formatted timestamp string.
    """
    total_ms = max(0, round(seconds * 1000))
    total_seconds, millis = divmod(total_ms, 1000)
    total_minutes, secs = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
29
 
30
def _captions_match(text, other, similarity_threshold):
    """Return True when two caption texts should count as duplicates."""
    import difflib

    if text == other:
        return True
    if not text or not other:
        # An empty caption only duplicates another empty caption.
        return False
    if text in other or other in text:
        return True
    ratio = difflib.SequenceMatcher(None, text, other).ratio()
    return ratio >= similarity_threshold


def clean_srt_duplicates(srt_content, time_threshold=30, similarity_threshold=0.9):
    """Remove near-duplicate captions from SRT text, keeping the last one.

    Two captions are duplicates when their start times are less than
    ``time_threshold`` seconds apart and their texts are identical, one
    contains the other, or their ``difflib`` similarity ratio is at least
    ``similarity_threshold``. Each time a caption is read, every earlier
    kept caption that duplicates it is dropped, so only the most recent
    occurrence survives. The result is renumbered from 1.

    Args:
        srt_content: SRT text made of ``index / start --> end / text`` blocks
            separated by blank lines.
        time_threshold: Maximum start-time distance, in seconds, for two
            captions to be considered duplicates.
        similarity_threshold: Minimum similarity ratio (0.0 to 1.0) for two
            different texts to be considered duplicates.

    Returns:
        The cleaned SRT text; an empty string if no block was recognised.
    """
    # Lazy text match up to the next blank line (or end of input) so that
    # multi-line captions stay in a single block.
    srt_pattern = re.compile(
        r"(\d+)\n(\d{2}:\d{2}:\d{2},\d{3}) --> (\d{2}:\d{2}:\d{2},\d{3})\n(.*?)(?=\n\n|\Z)",
        re.DOTALL,
    )

    # Each kept entry is (start_seconds, start_time, end_time, text).
    kept = []
    for match in srt_pattern.finditer(srt_content.replace("\r\n", "\n")):
        _, start_time, end_time, text = match.groups()
        text = text.strip()
        start_seconds = timestamp_to_seconds(start_time)

        # Drop earlier duplicates by comparing the stored entries directly;
        # list positions shift after removals and cannot serve as identifiers.
        kept = [
            entry
            for entry in kept
            if not (
                abs(start_seconds - entry[0]) < time_threshold
                and _captions_match(text, entry[3], similarity_threshold)
            )
        ]
        kept.append((start_seconds, start_time, end_time, text))

    return "".join(
        f"{number}\n{start_time} --> {end_time}\n{text}\n\n"
        for number, (_, start_time, end_time, text) in enumerate(kept, 1)
    )
74
+
75
_ASR_MODEL_ID = "openai/whisper-large-v3"
_CHUNK_SECONDS = 10
# Holds the built pipeline so the model is loaded once per process rather
# than once per request.
_asr_pipeline_cache = {}


def _get_asr_pipeline():
    """Return the speech-recognition pipeline, building it on first use."""
    cached = _asr_pipeline_cache.get(_ASR_MODEL_ID)
    if cached is not None:
        return cached

    use_cuda = torch.cuda.is_available()
    device = "cuda:0" if use_cuda else "cpu"
    torch_dtype = torch.float16 if use_cuda else torch.float32

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        _ASR_MODEL_ID,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(_ASR_MODEL_ID)

    asr = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=10,
        batch_size=2,
        return_timestamps=True,
        torch_dtype=torch_dtype,
        device=device,
    )
    _asr_pipeline_cache[_ASR_MODEL_ID] = asr
    return asr


def transcribe(video_file, transcribe_to_text, transcribe_to_srt, language):
    """Transcribe a video's audio track in 10-second chunks.

    This is a generator: it yields ``(message, srt_path)`` tuples. Progress
    and error messages are yielded with ``srt_path`` set to ``None``; the
    final tuple carries the assembled output text and, when SRT output was
    requested, the path of the written ``transcription.srt`` file.

    Args:
        video_file: Path to the video, or an object whose ``name`` attribute
            is the path. ``None`` produces an error message.
        transcribe_to_text: Include the plain-text transcription in the output.
        transcribe_to_srt: Build an SRT file from the chunk timestamps.
        language: Language code forwarded to the model's generate call.

    Yields:
        ``(str, str | None)`` tuples as described above.
    """
    import tempfile

    # Validate the input before doing anything expensive such as loading
    # the model.
    if video_file is None:
        yield "Error: No video file provided.", None
        return

    video_path = getattr(video_file, "name", video_file)

    try:
        video = VideoFileClip(video_path)
    except Exception as e:
        yield f"Error processing video file: {str(e)}", None
        return

    try:
        audio = video.audio
        if audio is None:
            yield "Error: The video has no audio track.", None
            return

        pipe = _get_asr_pipeline()
        duration = video.duration
        n_chunks = math.ceil(duration / _CHUNK_SECONDS)
        text_parts = []
        transcription_srt = []

        # Chunk files go in a private temporary directory, removed on exit.
        with tempfile.TemporaryDirectory() as temp_dir:
            for i in range(n_chunks):
                offset = i * _CHUNK_SECONDS
                end = min(offset + _CHUNK_SECONDS, duration)
                temp_file_path = os.path.join(temp_dir, f"temp_audio_{i}.wav")

                try:
                    audio.subclip(offset, end).write_audiofile(
                        temp_file_path,
                        codec='pcm_s16le',
                        verbose=False,
                        logger=None,
                    )
                    result = pipe(
                        temp_file_path,
                        generate_kwargs={"language": language},
                    )
                finally:
                    # Delete each chunk right away so a long video does not
                    # accumulate audio files on disk.
                    if os.path.exists(temp_file_path):
                        os.remove(temp_file_path)

                text_parts.append(result["text"])

                if transcribe_to_srt:
                    for chunk in result["chunks"]:
                        start_time, end_time = chunk["timestamp"]
                        # Skip segments lacking either timestamp.
                        if start_time is not None and end_time is not None:
                            # Timestamps are relative to the chunk; shift
                            # them to positions within the whole video.
                            transcription_srt.append({
                                "start": start_time + offset,
                                "end": end_time + offset,
                                "text": chunk["text"].strip(),
                            })

                yield f"Progress: {int(((i + 1) / n_chunks) * 100)}%", None

        output = ""
        srt_file_path = None

        if transcribe_to_text:
            transcription_txt = "".join(text_parts)
            output += "Text Transcription:\n" + transcription_txt.strip() + "\n\n"

        if transcribe_to_srt:
            output += "SRT Transcription:\n"
            srt_content = "".join(
                f"{i}\n{format_time(sub['start'])} --> {format_time(sub['end'])}\n{sub['text']}\n\n"
                for i, sub in enumerate(transcription_srt, 1)
            )
            cleaned_srt_content = clean_srt_duplicates(srt_content)

            srt_file_path = "transcription.srt"
            with open(srt_file_path, "w", encoding="utf-8") as srt_file:
                srt_file.write(cleaned_srt_content)

            output += f"\nSRT file saved as: {srt_file_path}"

        yield output, srt_file_path

    except Exception as e:
        # UI boundary: report the failure as a message instead of raising.
        yield f"Error during transcription: {str(e)}", None
    finally:
        # Close the clip on every exit path, including errors.
        video.close()
205
 
206
# Gradio UI: one video input, two output toggles, and a language selector,
# wired to the transcribe generator.
_LANGUAGE_CODES = ['en', 'he', 'it', 'es', 'fr', 'de', 'zh', 'ar']

_DESCRIPTION = "\n".join([
    "",
    "Upload a video file to transcribe its audio using Whisper Large V3.",
    "You can generate both text and SRT format transcriptions.",
    "Supported languages: English (en), Hebrew (he), Italian (it), Spanish (es),",
    "French (fr), German (de), Chinese (zh), Arabic (ar)",
    "",
])

_video_input = gr.Video(label="Upload Video")
_text_toggle = gr.Checkbox(label="Transcribe to Text", value=True)
_srt_toggle = gr.Checkbox(label="Transcribe to SRT", value=True)
_language_choice = gr.Dropdown(
    choices=_LANGUAGE_CODES,
    value='en',
    label="Language",
)

iface = gr.Interface(
    fn=transcribe,
    inputs=[_video_input, _text_toggle, _srt_toggle, _language_choice],
    outputs=[
        gr.Textbox(label="Transcription Output"),
        gr.File(label="Download SRT"),
    ],
    title="WhisperCap Video Transcription",
    description=_DESCRIPTION,
    allow_flagging="never",
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch(share=True)