Haopeng committed
Commit db419eb · 1 Parent(s): 1507eed

first commit

Files changed (2)
  1. app.py +104 -4
  2. requirement.txt +3 -0
app.py CHANGED
@@ -1,7 +1,107 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
 
+ # App for summarizing the video/audio input and an uploaded PDF file for joint summarization.
+
  import gradio as gr
+ from transformers import pipeline
+ import torch
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
+ import torchaudio
+
+ # Pick the device: prefer CUDA when available, otherwise fall back to CPU
+ # (MPS support for Apple silicon is sketched in the commented lines below).
+ # if torch.backends.mps.is_available():
+ #     device = torch.device('mps')
+ # else:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+ # Initialize the Whisper model pipeline
+ asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
+
+ # for filler: load model and processor directly (see filler_transcribe_with_timestamps below)
+
+ def transcribe_with_timestamps(audio):
+     # Use the pipeline to transcribe the audio with word-level timestamps
+     result = asr_pipeline(audio, return_timestamps="word")
+     return result["text"], result["chunks"]
+
+ def filler_transcribe_with_timestamps(audio, filler=False):
+     processor = WhisperProcessor.from_pretrained("openai/whisper-base")
+     processor_filler = WhisperProcessor.from_pretrained("openai/whisper-base", normalize=False, return_timestamps="word")
+     model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(device)
+
+     # load the uploaded audio file (kept on CPU for the feature extractor)
+     sample, sr = torchaudio.load(audio)
+     # if sr != 16000, resample to 16000
+     if sr != 16000:
+         sample = torchaudio.transforms.Resample(sr, 16000)(sample)
+         sr = 16000
+
+     input_features = processor(sample.squeeze(), sampling_rate=sr, return_tensors="pt").input_features.to(device)
+
+     # generate token ids, then decode them to text
+     if filler:
+         predicted_ids = model.generate(input_features, return_timestamps=True)
+         # decode token ids to text without normalisation (keeps filler words)
+         transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=False)
+     else:
+         predicted_ids = model.generate(input_features)
+         # decode token ids to text with normalisation
+         transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)
+
+     return transcription
+     # print(transcription)
+     # Use the pipeline to transcribe the audio with timestamps
+     # return result["text"], result["chunks"]
+
+ # # Set up Gradio interface
+ # interface = gr.Interface(
+ #     fn=transcribe_with_timestamps,
+ #     inputs=gr.Audio(label="Upload audio", type="filepath"),
+ #     outputs=[gr.Textbox(label="Transcription"), gr.JSON(label="Timestamps")],
+ #     title="Academic presentation Agent",
+ # )
+
+ Instructions = """
+ # Academic Presentation Agent
+ Upload a video/audio file to transcribe the audio with word-level timestamps.
+ Optionally, also upload a PDF file to be summarized alongside the transcript.
+ The model will return the transcription and the timestamps of the audio.
+ """
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(Instructions)
+     with gr.Column():
+         with gr.Row():
+             input_audio = gr.Audio(label="Upload audio", type="filepath")
+             # Dummy PDF input (not yet wired to a summarizer)
+             input_pdf = gr.File(label="Upload PDF", type="filepath")
+     with gr.Column():
+         with gr.Row():
+             transcription = gr.Textbox(label="Transcription")
+         with gr.Row():
+             with gr.Accordion("Timestamps", open=False):
+                 timestamps = gr.JSON(label="Timestamps")
+         with gr.Row():
+             transcribe_button = gr.Button("Transcribe")
+         # ASR summary outputs: transcription text and word-level timestamps
+         ASR_summary = [transcription, timestamps]
+         transcribe_button.click(transcribe_with_timestamps, inputs=input_audio, outputs=ASR_summary)
+         with gr.Row():
+             analyze_button = gr.Button("Analyze")
+
+     # with gr.Column():
+     #     with gr.Row():
+     #         input_audio = gr.Audio(label="Upload audio", type="filepath")
+     #         transcription = gr.Textbox(label="Transcription")
+     #         timestamps = gr.JSON(label="Timestamps")
+     #     with gr.Row():
+     #         transcribe_button_filler = gr.Button("Transcribe_filler")
+     #     # ASR summary
+     #     ASR_summary = [transcription, timestamps]
+     #     transcribe_button_filler.click(filler_transcribe_with_timestamps, input_audio, outputs=transcription)
+
+ # Launch the Gradio app
+ demo.launch(share=False)
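For reference, the transformers ASR pipeline used by transcribe_with_timestamps returns a dict with the full text plus per-word chunks; the sketch below shows the rough shape of what ends up in the Transcription textbox and the Timestamps JSON panel. The words and times are made up for illustration and are not output of this commit.

# Rough shape of asr_pipeline(audio, return_timestamps="word"); values are illustrative only.
result = {
    "text": " Hello everyone, welcome to my talk.",
    "chunks": [
        {"text": " Hello", "timestamp": (0.0, 0.4)},
        {"text": " everyone,", "timestamp": (0.4, 0.9)},
        # ...one entry per decoded word...
    ],
}

# transcribe_with_timestamps returns (result["text"], result["chunks"]),
# which Gradio assigns in order to the [transcription, timestamps] outputs.
text, chunks = result["text"], result["chunks"]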
requirement.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ gradio
+ transformers
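Note that app.py also imports torchaudio, which is not listed in requirement.txt, so an environment built from this file alone would fail at import time. A plausible superset of the dependencies (package names only, versions unpinned, shown here purely as an illustration) would be:

torch
torchaudio
gradio
transformers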