bluenevus committed · verified
Commit fce37ea · Parent(s): 640b5e1

Update app.py

Files changed (1): app.py +41 -10
app.py CHANGED

@@ -1,6 +1,6 @@
 import io
 import torch
-from transformers import WhisperProcessor, WhisperForConditionalGeneration
+from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM
 import requests
 from bs4 import BeautifulSoup
 import tempfile
@@ -20,9 +20,14 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
 
 # Load the Whisper model and processor
-model_name = "openai/whisper-small"
-processor = WhisperProcessor.from_pretrained(model_name)
-model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
+whisper_model_name = "openai/whisper-small"
+whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
+whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name).to(device)
+
+# Load the Qwen model and tokenizer
+qwen_model_name = "Qwen/Qwen2.5-3B-Instruct"
+qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name)
+qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name).to(device)
 
 def download_audio_from_url(url):
     try:
@@ -65,9 +70,9 @@ def transcribe_audio(audio_file):
         audio_array = audio.get_array_of_samples()
 
         print("Starting transcription...")
-        input_features = processor(audio_array, sampling_rate=16000, return_tensors="pt").input_features.to(device)
-        predicted_ids = model.generate(input_features)
-        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+        input_features = whisper_processor(audio_array, sampling_rate=16000, return_tensors="pt").input_features.to(device)
+        predicted_ids = whisper_model.generate(input_features)
+        transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
 
         print(f"Transcription complete. Length: {len(transcription[0])} characters")
         return transcription[0]
@@ -75,6 +80,28 @@ def transcribe_audio(audio_file):
         print(f"Error in transcribe_audio: {str(e)}")
         raise
 
+def separate_speakers(transcription):
+    prompt = f"""Analyze the following transcribed text and separate it into different speakers. Identify potential speaker changes based on context, content shifts, or dialogue patterns. Format the output as follows:
+
+1. Label speakers as "Speaker 1", "Speaker 2", etc.
+2. Start each speaker's text on a new line beginning with their label.
+3. Separate different speakers' contributions with a blank line.
+4. If the same speaker continues, do not insert a blank line or repeat the speaker label.
+
+Now, please process the following transcribed text:
+
+{transcription}
+"""
+
+    inputs = qwen_tokenizer(prompt, return_tensors="pt").to(device)
+    outputs = qwen_model.generate(**inputs, max_new_tokens=1000)
+    result = qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Extract the processed text (remove the instruction part)
+    processed_text = result.split("Now, please process the following transcribed text:")[-1].strip()
+
+    return processed_text
+
 def transcribe_video(url):
     try:
         print(f"Attempting to download audio from URL: {url}")
@@ -86,7 +113,11 @@ def transcribe_video(url):
         transcript = transcribe_audio(temp_audio.name)
 
         os.unlink(temp_audio.name)
-        return transcript
+
+        print("Separating speakers...")
+        separated_transcript = separate_speakers(transcript)
+
+        return separated_transcript
     except Exception as e:
         error_message = f"An error occurred: {str(e)}"
         print(error_message)
@@ -97,7 +128,7 @@ app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
 app.layout = dbc.Container([
     dbc.Row([
         dbc.Col([
-            html.H1("Video Transcription", className="text-center mb-4"),
+            html.H1("Video Transcription with Speaker Separation", className="text-center mb-4"),
             dbc.Card([
                 dbc.CardBody([
                     dbc.Input(id="video-url", type="text", placeholder="Enter video URL"),
@@ -139,7 +170,7 @@ def update_transcription(n_clicks, url):
         download_data = dict(content=transcript, filename="transcript.txt")
         return dbc.Card([
             dbc.CardBody([
-                html.H5("Transcription Result"),
+                html.H5("Transcription Result with Speaker Separation"),
                 html.Pre(transcript, style={"white-space": "pre-wrap", "word-wrap": "break-word"}),
                 dbc.Button("Download Transcript", id="btn-download", color="secondary", className="mt-3")
             ])
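
A caveat on the new separate_speakers() function, offered as a suggestion rather than part of this commit: it feeds a raw prompt string to an instruct-tuned checkpoint and then strips the echoed prompt by string-splitting on a sentence from the instructions, which breaks if the model rephrases that sentence. A minimal alternative sketch, assuming the variable names from the diff above (qwen_tokenizer, qwen_model, device, prompt):

messages = [{"role": "user", "content": prompt}]
# Qwen2.5-Instruct models are trained on a chat format; apply_chat_template
# wraps the prompt in that format and appends the assistant turn marker.
input_ids = qwen_tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(device)
outputs = qwen_model.generate(input_ids, max_new_tokens=1000)
# Decode only the newly generated tokens, so no prompt-stripping is needed.
processed_text = qwen_tokenizer.decode(
    outputs[0][input_ids.shape[-1]:], skip_special_tokens=True
).strip()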
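
Memory footprint is another point the diff leaves open: whisper-small and a 3B-parameter Qwen model are both moved to the same device. If GPU memory is tight, a standard transformers option is to load the larger model in half precision, sketched here under the assumption of a CUDA device:

# torch_dtype=torch.float16 roughly halves the 3B model's GPU footprint;
# on CPU, the default dtype is usually the safer choice.
qwen_model = AutoModelForCausalLM.from_pretrained(
    qwen_model_name, torch_dtype=torch.float16
).to(device)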
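
For a quick end-to-end check of the changed behavior (the URL below is a placeholder, not from the commit):

# transcribe_video() now returns the speaker-separated transcript
# instead of the raw Whisper output.
separated = transcribe_video("https://example.com/talk.mp4")
print(separated)
# Expected shape, per the formatting rules in the prompt:
# Speaker 1: ...
#
# Speaker 2: ...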