reab5555 committed
Commit 818cd17 · verified · 1 Parent(s): 117dfa6

Update transcription_diarization.py

Files changed (1):
  transcription_diarization.py  +115 -184
transcription_diarization.py CHANGED
@@ -1,192 +1,123 @@
 import os
-import torch
-import gc
-from moviepy.editor import VideoFileClip
-from pyannote.audio import Pipeline
-import datetime
-from collections import defaultdict
-from openai import OpenAI
-from config import openai_api_key, hf_token
-from pydub import AudioSegment, silence
-import math

-client = OpenAI(api_key=openai_api_key)

-class LazyDiarizationPipeline:
-    def __init__(self):
-        self.pipeline = None
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-    def get_pipeline(self, hf_token):
-        if self.pipeline is None:
-            self.pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1",
-                                                     use_auth_token=hf_token)
-            self.pipeline = self.pipeline.to(self.device)
-        return self.pipeline

-lazy_diarization_pipeline = LazyDiarizationPipeline()

-def extract_audio(video_path):
-    base_name = os.path.splitext(video_path)[0]
-    audio_path = f"{base_name}.wav"
-    video = VideoFileClip(video_path)
-    audio = video.audio
-    # Reduce audio quality to keep file size smaller
-    audio.write_audiofile(audio_path, codec='pcm_s16le', fps=16000, nbytes=2)
-    return audio_path

-def format_timestamp(seconds):
-    return str(datetime.timedelta(seconds=round(seconds))).zfill(8)

-def diarize_audio(audio_path, pipeline, max_speakers):
-    diarization = pipeline(audio_path, num_speakers=max_speakers)
-    return diarization

-def split_audio_on_silence(audio_path, min_silence_len=500, silence_thresh=-40, keep_silence=500):
-    audio = AudioSegment.from_wav(audio_path)
-    chunks = silence.split_on_silence(
-        audio,
-        min_silence_len=min_silence_len,
-        silence_thresh=silence_thresh,
-        keep_silence=keep_silence
     )

-    chunk_paths = []
-    for i, chunk in enumerate(chunks):
-        chunk_path = f"{audio_path[:-4]}_chunk_{i}.wav"
-        chunk.export(chunk_path, format="wav")
-        chunk_paths.append(chunk_path)

-    return chunk_paths

-def transcribe_audio(audio_path, language):
-    with open(audio_path, "rb") as audio_file:
-        transcript = client.audio.transcriptions.create(
-            file=audio_file,
-            model="whisper-1",
-            language=language,
-            response_format="verbose_json"
-        )

-    if not isinstance(transcript, dict):
-        transcript = transcript.model_dump()

-    transcription_txt = transcript.get("text", "")
-    transcription_chunks = []

-    for segment in transcript.get("segments", []):
-        transcription_chunks.append({
-            "start": segment.get("start", 0),
-            "end": segment.get("end", 0),
-            "text": segment.get("text", "")
         })

-    return transcription_txt, transcription_chunks

-def transcribe_large_audio(audio_path, language):
-    chunk_paths = split_audio_on_silence(audio_path)
-    transcription_txt = ""
-    transcription_chunks = []

-    for chunk_path in chunk_paths:
-        chunk_transcription_txt, chunk_transcription_chunks = transcribe_audio(chunk_path, language)
-        transcription_txt += chunk_transcription_txt
-        transcription_chunks.extend(chunk_transcription_chunks)
-        os.remove(chunk_path)

-    return transcription_txt, transcription_chunks

-def create_combined_srt(transcription_chunks, diarization, output_path, max_speakers):
-    speaker_segments = []
-    speaker_durations = defaultdict(float)

-    for segment, _, speaker in diarization.itertracks(yield_label=True):
-        speaker_durations[speaker] += segment.end - segment.start
-        speaker_segments.append((segment.start, segment.end, speaker))

-    sorted_speakers = sorted(speaker_durations.items(), key=lambda x: x[1], reverse=True)[:max_speakers]

-    speaker_map = {}
-    for i, (speaker, _) in enumerate(sorted_speakers, start=1):
-        speaker_map[speaker] = f"Speaker {i}"

-    with open(output_path, 'w', encoding='utf-8') as srt_file:
-        current_speaker = "Unknown"
-        current_text = ""
-        current_start = 0
-        current_end = 0
-        entry_count = 0

-        def write_entry():
-            nonlocal entry_count, current_speaker, current_start, current_end, current_text
-            if current_text:
-                entry_count += 1
-                start_str = format_timestamp(current_start)
-                end_str = format_timestamp(current_end)
-                srt_file.write(f"[{entry_count}. {current_speaker} | time: ({start_str} --> {end_str}) | text: {current_text.strip()}]\n\n")

-        for chunk in transcription_chunks:
-            start_time, end_time = chunk["start"], chunk["end"]
-            text = chunk["text"]

-            # Avoid splitting mid-sentence
-            if current_text and (text[0].isupper() or text.startswith(('.', '?', '!', '...'))):
-                write_entry()
-                current_speaker = "Unknown"
-                for seg_start, seg_end, speaker in speaker_segments:
-                    if seg_start <= start_time < seg_end:
-                        current_speaker = speaker_map.get(speaker, "Unknown")
-                        break

-                current_text = ""
-                current_start = start_time

-            current_text += " " + text
-            current_end = end_time

-            # Write entry if sentence ends with a punctuation mark
-            if current_text.strip().endswith(('.', '?', '!', '...')):
-                write_entry()
-                current_text = ""
-                current_start = end_time

-        write_entry()

-    with open(output_path, 'a', encoding='utf-8') as srt_file:
-        for i, (speaker, duration) in enumerate(sorted_speakers, start=1):
-            duration_str = format_timestamp(duration)
-            srt_file.write(f"Speaker {i} (originally {speaker}): total duration {duration_str}\n")

-def process_video(video_path, hf_token, language, max_speakers=3):
-    audio_path = extract_audio(video_path)

-    pipeline = lazy_diarization_pipeline.get_pipeline(hf_token)
-    diarization = diarize_audio(audio_path, pipeline, max_speakers)

-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-    gc.collect()

-    transcription, chunks = transcribe_large_audio(audio_path, language)

-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-    gc.collect()

-    combined_srt_path = f"{os.path.splitext(video_path)[0]}_combined.srt"
-    create_combined_srt(chunks, diarization, combined_srt_path, max_speakers)

-    os.remove(audio_path)

-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-    gc.collect()

-    # Convert the diarization results to a readable format
-    diarization_output = ""
-    for turn, _, speaker in diarization.itertracks(yield_label=True):
-        start_time = format_timestamp(turn.start)
-        end_time = format_timestamp(turn.end)
-        diarization_output += f"Speaker {speaker}: {start_time} --> {end_time}\n"

-    return combined_srt_path, diarization_output
 
+import boto3
+import time
+import json
 import os
+from config import aws_access_key_id, aws_secret_access_key
+
+def upload_to_s3(local_file_path, bucket_name, s3_file_key):
+    s3_client = boto3.client('s3',
+                             aws_access_key_id=aws_access_key_id,
+                             aws_secret_access_key=aws_secret_access_key,
+                             region_name='eu-central-1')
+    s3_client.upload_file(local_file_path, bucket_name, s3_file_key)
+    return f's3://{bucket_name}/{s3_file_key}'
+
+def transcribe_video(file_uri, job_name, max_speakers):
+    transcribe = boto3.client('transcribe',
+                              aws_access_key_id=aws_access_key_id,
+                              aws_secret_access_key=aws_secret_access_key,
+                              region_name='eu-central-1')
+
+    transcribe.start_transcription_job(
+        TranscriptionJobName=job_name,
+        Media={'MediaFileUri': file_uri},
+        MediaFormat='mp4',
+        IdentifyLanguage=True,
+        Settings={
+            'ShowSpeakerLabels': True,
+            'MaxSpeakerLabels': max_speakers
+        }
     )

+    while True:
+        status = transcribe.get_transcription_job(TranscriptionJobName=job_name)
+        if status['TranscriptionJob']['TranscriptionJobStatus'] in ['COMPLETED', 'FAILED']:
+            break
+        print("Waiting for transcription to complete...")
+        time.sleep(30)

+    if status['TranscriptionJob']['TranscriptionJobStatus'] == 'COMPLETED':
+        transcript_url = status['TranscriptionJob']['Transcript']['TranscriptFileUri']
+        print("Transcription completed successfully!")
+        return transcript_url
+    else:
+        print("Transcription failed.")
+        return None
+
+def download_transcript(transcript_url):
+    s3_client = boto3.client('s3',
+                             aws_access_key_id=aws_access_key_id,
+                             aws_secret_access_key=aws_secret_access_key,
+                             region_name='eu-central-1')

+    bucket_name = transcript_url.split('/')[2]
+    key = '/'.join(transcript_url.split('/')[3:])
+
+    response = s3_client.get_object(Bucket=bucket_name, Key=key)
+    transcript_content = response['Body'].read().decode('utf-8')
+    return json.loads(transcript_content)
+
+def extract_transcriptions_with_speakers(transcript_data):
+    segments = transcript_data['results']['speaker_labels']['segments']
+    items = transcript_data['results']['items']
+
+    current_speaker = None
+    current_text = []
+    transcriptions = []
+
+    for item in items:
+        if item['type'] == 'pronunciation':
+            start_time = float(item['start_time'])
+            end_time = float(item['end_time'])
+            content = item['alternatives'][0]['content']
+
+            speaker_segment = next((seg for seg in segments if float(seg['start_time']) <= start_time and float(seg['end_time']) >= end_time), None)
+
+            if speaker_segment and speaker_segment['speaker_label'] != current_speaker:
+                if current_text:
+                    transcriptions.append({
+                        'speaker': current_speaker,
+                        'text': ' '.join(current_text)
+                    })
+                    current_text = []
+                current_speaker = speaker_segment['speaker_label']
+
+            current_text.append(content)
+        elif item['type'] == 'punctuation':
+            current_text[-1] += item['alternatives'][0]['content']
+
+    if current_text:
+        transcriptions.append({
+            'speaker': current_speaker,
+            'text': ' '.join(current_text)
         })

+    return transcriptions

+def process_video(video_path, bucket_name, max_speakers):
+    # Upload video to S3
+    s3_file_key = os.path.basename(video_path)
+    file_uri = upload_to_s3(video_path, bucket_name, s3_file_key)

+    # Start transcription job
+    job_name = f'transcription_job_{int(time.time())}'
+    transcript_url = transcribe_video(file_uri, job_name, max_speakers)
+
+    if transcript_url:
+        # Download and process transcript
+        transcript_data = download_transcript(transcript_url)
+        transcriptions = extract_transcriptions_with_speakers(transcript_data)
+
+        # Create combined SRT-like output
+        output = []
+        for i, trans in enumerate(transcriptions, 1):
+            output.append(f"[{i}. {trans['speaker']} | text: {trans['text']}]\n")
+
+        return '\n'.join(output)
+    else:
+        return "Transcription failed."
+
+# This function will be called from the Gradio app
+def diarize_audio(video_path, max_speakers):
+    bucket_name = 'transcriptionjobbucket'  # Replace with your actual S3 bucket name
+    return process_video(video_path, bucket_name, max_speakers)
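
The new diarize_audio(video_path, max_speakers) wrapper at the end of the file is the entry point the Gradio app is expected to call. A minimal sketch of that wiring, assuming a simple gradio Interface with a video upload and a speaker-count slider; the component layout, labels, and launch code below are illustrative assumptions, not part of this commit:

# Illustrative sketch only: one way the new diarize_audio entry point could be
# wired into a Gradio app. The interface layout and labels are assumptions.
import gradio as gr

from transcription_diarization import diarize_audio

demo = gr.Interface(
    fn=diarize_audio,                                # entry point added in this commit
    inputs=[
        gr.Video(label="Video"),                     # passed to diarize_audio as video_path
        gr.Slider(1, 10, value=3, step=1,
                  label="Max speakers"),             # forwarded as max_speakers
    ],
    outputs=gr.Textbox(label="Speaker-labelled transcript"),
)

if __name__ == "__main__":
    demo.launch()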