import os
from pprint import pprint

# install Whisper at runtime (this Space has no pinned requirement for it)
os.system("pip install git+https://github.com/openai/whisper.git")

import gradio as gr
import torch
import whisper
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# import streaming.py
# from next_word_prediction import GPT2

### code snippet
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2", return_dict_in_generate=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
### /code snippet

from share_btn import community_icon_html, loading_icon_html, share_js

# get gpt2 model
generator = pipeline('text-generation', model='gpt2')

# whisper model specification
model = whisper.load_model("tiny")


def buttonValues():
    # placeholder handler for the extra button below; it takes no inputs
    return "Hello"


def inference(audio):
    # load audio data
    audio = whisper.load_audio(audio)
    # ensure sample is in correct format for inference
    audio = whisper.pad_or_trim(audio)

    # generate a log-mel spectrogram of the audio data
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # detect the spoken language (probabilities are currently unused)
    _, lang_probs = model.detect_language(mel)

    # decode audio data
    options = whisper.DecodingOptions(fp16=False)
    # transcribe speech to text
    result = whisper.decode(model, mel, options)

    # Added prompt below
    input_prompt = "The following is a transcript of someone talking, please predict what they will say next. \n"

    ### code
    input_total = input_prompt + result.text
    input_ids = tokenizer(input_total, return_tensors="pt").input_ids
    print("inputs ", input_ids)

    # prompt length
    # prompt_length = len(tokenizer.decode(input_ids[0]))

    # length penalty for gpt2.generate???
    # Prompt
    generated_outputs = gpt2.generate(
        input_ids,
        do_sample=True,
        num_return_sequences=3,
        max_new_tokens=15,  # generate 15 new tokens so the shapes noted below hold
        output_scores=True,
    )
    print("outputs generated ", generated_outputs[0])

    # only use ids that were generated
    # gen_sequences has shape [num_return_sequences, num_generated_tokens], e.g. [3, 15]
    gen_sequences = generated_outputs.sequences[:, input_ids.shape[-1]:]
    print("gen sequences: ", gen_sequences)

    # let's stack the logits generated at each step to a tensor and transform
    # logits to probs
    probs = torch.stack(generated_outputs.scores, dim=1).softmax(-1)  # -> shape [3, 15, vocab_size]

    # now we need to collect the probability of the generated token
    # we need to add a dummy dim in the end to make gather work
    gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
    print("gen probs result: ", gen_probs)

    # now we can do all kinds of things with the probs

    # 1) the probs that exactly those sequences are generated again
    # those are normally going to be very small
    # unique_prob_per_sequence = gen_probs.prod(-1)

    # 2) normalize the probs over the three sequences
    # normed_gen_probs = gen_probs / gen_probs.sum(0)
    # assert normed_gen_probs[:, 0].sum() == 1.0, "probs should be normalized"

    # 3) compare normalized probs to each other like in 1)
    # unique_normed_prob_per_sequence = normed_gen_probs.prod(-1)
    ### end code

    # print audio data as text
    # print(result.text)

    # prompt: generate five short continuations of the transcript
    getText = generator(result.text, max_new_tokens=10, num_return_sequences=5)
    # pprint("getText: ", getText)
    # pprint("text.result: ", result.text)

    # return the continuations as plain text; the click handler below has a single
    # Textbox output, so the extra gr.update() values for the removed share-button
    # components are no longer returned
    return "\n".join(g["generated_text"] for g in getText)


css = """
        .gradio-container {
            font-family: 'IBM Plex Sans', sans-serif;
        }
        .gr-button {
            color: white;
            border-color: black;
            background: black;
        }
        input[type='range'] {
            accent-color: black;
        }
        .dark input[type='range'] {
            accent-color: #dfdfdf;
        }
        .container {
            max-width: 730px;
            margin: auto;
            padding-top: 1.5rem;
        }
        .details:hover {
            text-decoration: underline;
        }
        .gr-button {
            white-space: nowrap;
        }
        .gr-button:focus {
            border-color: rgb(147 197 253 / var(--tw-border-opacity));
            outline: none;
            box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
            --tw-border-opacity: 1;
            --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
            --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
            --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
            --tw-ring-opacity: .5;
        }
        .footer {
            margin-bottom: 45px;
            margin-top: 35px;
            text-align: center;
            border-bottom: 1px solid #e5e5e5;
        }
        .footer>p {
            font-size: .8rem;
            display: inline-block;
            padding: 0 10px;
            transform: translateY(10px);
            background: white;
        }
        .dark .footer {
            border-color: #303030;
        }
        .dark .footer>p {
            background: #0b0f19;
        }
        .prompt h4 {
            margin: 1.25em 0 .25em 0;
            font-weight: bold;
            font-size: 115%;
        }
        .animate-spin {
            animation: spin 1s linear infinite;
        }
        @keyframes spin {
            from {
                transform: rotate(0deg);
            }
            to {
                transform: rotate(360deg);
            }
        }
        #share-btn-container {
            display: flex;
            margin-top: 1.5rem !important;
            padding-left: 0.5rem !important;
            padding-right: 0.5rem !important;
            background-color: #000000;
            justify-content: center;
            align-items: center;
            border-radius: 9999px !important;
            width: 13rem;
        }
        #share-btn {
            all: initial;
            color: #ffffff;
            font-weight: 600;
            cursor: pointer;
            font-family: 'IBM Plex Sans', sans-serif;
            margin-left: 0.5rem !important;
            padding-top: 0.25rem !important;
            padding-bottom: 0.25rem !important;
        }
        #share-btn * {
            all: unset;
        }
"""

block = gr.Blocks(css=css)

with block:
    gr.HTML(
        """

Whisper

Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multi-task model that can perform multilingual speech recognition as well as speech translation and language identification. This demo cuts audio after about 30 seconds.

You can skip the queue by using the Google Colab version of this space: Open In Colab

""" ) with gr.Group(): with gr.Box(): with gr.Row().style(mobile_collapse=False, equal_height=True): # get audio from microphone audio = gr.Audio( label="Input Audio", show_label=False, source="microphone", type="filepath" ) btn = gr.Button("Transcribe") text = gr.Textbox(show_label=False, elem_id="result-textarea") # added rText below # rText = gr.Textbox(show_label=False, elem_id="result-textarea") buttonV = gr.Button(" ") buttonV.click(buttonValues, inputs=[], outputs=[]) btn.click(inference, inputs=[audio], outputs=[text]) gr.HTML(''' ''') block.launch()