File size: 7,695 Bytes
22e2ace
0d7c0a3
 
 
 
 
 
d1b2128
0d7c0a3
afd80fc
22e2ace
0d7c0a3
 
3cad491
 
 
 
 
 
 
 
0d7c0a3
7143888
afd80fc
 
9279bb0
a6cdcf6
c9b4e47
a6cdcf6
c9b4e47
afd80fc
 
 
 
05aa177
9279bb0
afd80fc
 
0d7c0a3
7143888
930e925
 
889e7e2
 
0d7c0a3
 
889e7e2
 
0d7c0a3
889e7e2
0d7c0a3
930e925
3cad491
 
 
 
 
 
 
0d7c0a3
 
 
 
c9b4e47
0d7c0a3
 
 
 
 
 
 
7143888
0d7c0a3
7143888
0d7c0a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9b4e47
 
 
 
 
0d7c0a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
889e7e2
0d7c0a3
 
 
afd80fc
c9b4e47
 
 
 
 
3205cb6
afd80fc
930e925
c9b4e47
 
 
dd403f4
 
c9b4e47
 
 
 
 
0d7c0a3
3cad491
 
 
 
fe6e529
3cad491
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe6e529
 
 
930e925
0d7c0a3
c9b4e47
0d7c0a3
 
 
 
 
 
afd80fc
 
 
 
 
 
9279bb0
afd80fc
0d7c0a3
 
 
 
 
 
 
 
 
 
 
c9b4e47
0d7c0a3
c9b4e47
 
 
0d7c0a3
 
afd80fc
 
c9b4e47
7143888
9279bb0
9c4c9fa
0d7c0a3
c9b4e47
3205cb6
d1b2128
a2da9cb
0d7c0a3
c9b4e47
0d7c0a3
 
afd80fc
0d7c0a3
 
 
7143888
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import gradio as gr
import json
import librosa
import os
import soundfile as sf
import tempfile
import uuid
import transformers
import torch
import time

from nemo.collections.asr.models import ASRModel

from transformers import GemmaTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# Hugging Face access token read from the environment (None if unset).
HF_TOKEN = os.environ.get("HF_TOKEN", None)


SAMPLE_RATE = 16000  # Hz; audio is resampled to this rate before ASR
MAX_AUDIO_SECONDS = 40  # won't try to transcribe if longer than this

# HTML header shown above the chat. The audio-length limit is interpolated from
# MAX_AUDIO_SECONDS so the text cannot drift from the limit actually enforced.
# NOTE(review): VITS/TTS is advertised here but nothing in this file produces
# speech output — confirm whether TTS is still planned.
DESCRIPTION = f'''
<div>
<h1 style='text-align: center'>MyAlexa: Voice Chat Assistant</h1>
<p style='text-align: center'>MyAlexa is a demo of a voice chat assistant with chat logs that accepts audio input and outputs an AI response. </p>
<p>This space uses <a href="https://huggingface.co/nvidia/canary-1b"><b>NVIDIA Canary 1B</b></a> for Automatic Speech-to-text Recognition (ASR), <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"><b>Meta Llama 3 8B Instruct</b></a> for the large language model (LLM) and <a href="https://huggingface.co/docs/transformers/en/model_doc/vits"><b>VITS</b></a> for text to speech (TTS).</p>
<p>This demo accepts audio inputs not more than {MAX_AUDIO_SECONDS} seconds long.</p>
<p>Transcription and responses are limited to the English language.</p>
</div>
'''
# HTML shown inside the empty chatbot before the first message.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <img src="https://i.ibb.co/S35q17Q/My-Alexa-Logo.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;  "> 
   <p style="font-size: 28px; margin-bottom: 2px; opacity: 0.65;">What's on your mind?</p>
</div>
"""

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### ASR model
# Load the NVIDIA Canary 1B speech-recognition model, move it to `device`,
# and put it in eval (inference) mode.
canary_model = ASRModel.from_pretrained("nvidia/canary-1b").to(device)
canary_model.eval()

# make sure beam size always 1 for consistency
# Calling with None first makes the model (re)apply its own decoding config
# (presumably its stored default — confirm for the NeMo version in use); that
# config is then read back, its beam size forced to 1, and applied again.
canary_model.change_decoding_strategy(None)
decoding_cfg = canary_model.cfg.decoding
decoding_cfg.beam.beam_size = 1
canary_model.change_decoding_strategy(decoding_cfg)

### LLM model
# Load the tokenizer and model. The Meta Llama 3 repository is gated, so the
# token read from the HF_TOKEN environment variable is passed explicitly
# (None keeps the library's default credential lookup).
LLM_MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, token=HF_TOKEN)
# device_map="auto" lets the weights be placed on whatever device(s) are available.
llama3_model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_ID, device_map="auto", token=HF_TOKEN
)

# Generation stops at either the tokenizer's EOS token or Llama 3's
# end-of-turn token "<|eot_id|>".
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

def convert_audio(audio_filepath, tmpdir, utt_id):
    """Load an audio file and write it back out as a mono WAV at SAMPLE_RATE.

    The file is decoded at its native sample rate and down-mixed to mono.
    Audio longer than MAX_AUDIO_SECONDS is rejected before any resampling or
    writing takes place.

    Args:
        audio_filepath: Path of the input audio file.
        tmpdir: Directory in which the converted WAV file is written.
        utt_id: Base name (a str, without extension) for the output file.

    Returns:
        Tuple ``(wav_path, duration)``: the path of the written WAV file and
        the duration of the input audio in seconds.

    Raises:
        gr.Error: if the audio is longer than MAX_AUDIO_SECONDS.
    """
    samples, native_sr = librosa.load(audio_filepath, sr=None, mono=True)
    seconds = librosa.get_duration(y=samples, sr=native_sr)

    too_long = seconds > MAX_AUDIO_SECONDS
    if too_long:
        raise gr.Error(
            f"This demo can transcribe up to {MAX_AUDIO_SECONDS} seconds of audio. "
            "If you wish, you may trim the audio using the Audio viewer in Step 1 "
            "(click on the scissors icon to start trimming audio)."
        )

    # Resample only when the source rate differs from the target rate.
    if native_sr != SAMPLE_RATE:
        samples = librosa.resample(samples, orig_sr=native_sr, target_sr=SAMPLE_RATE)

    wav_path = os.path.join(tmpdir, utt_id + ".wav")
    sf.write(wav_path, samples, SAMPLE_RATE)
    return wav_path, seconds

def transcribe(audio_filepath):
    """Run Canary ASR on an audio file and return its first transcription result.

    The audio is converted with ``convert_audio`` inside a temporary directory.
    It is then described in a one-line JSON manifest that requests English ASR
    with punctuation ("pnc": "yes") and handed to ``canary_model.transcribe``.
    The temporary directory is removed once transcription finishes.

    Raises:
        gr.Error: if no audio was provided (``convert_audio`` may also raise
            it for over-long audio).
    """
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")

    utt_id = uuid.uuid4()
    with tempfile.TemporaryDirectory() as workdir:
        wav_path, duration = convert_audio(audio_filepath, workdir, str(utt_id))

        entry = dict(
            audio_filepath=wav_path,
            source_lang="en",
            target_lang="en",
            taskname="asr",
            pnc="yes",
            answer="predict",
            duration=str(duration),
        )
        manifest_path = os.path.join(workdir, f'{utt_id}.json')
        with open(manifest_path, 'w') as manifest:
            manifest.write(json.dumps(entry) + '\n')

        # The manifest holds a single utterance, so only the first output is used.
        result = canary_model.transcribe(manifest_path)[0]

    return result

def add_message(history, message):
    """Append the user's message to the chat history as a new, unanswered turn.

    Args:
        history: Existing chatbot history as a sequence of (user, assistant)
            pairs; None is treated as an empty history.
        message: The user's (transcribed) message.

    Returns:
        A new list containing the previous turns followed by
        ``[message, None]``. The assistant slot is left as None for the
        response step to fill in.
    """
    updated = list(history) if history is not None else []
    # A list (not a tuple) so the assistant slot can later be assigned in place.
    updated.append([message, None])
    return updated

def bot(history, message):
    """Generate the assistant's reply to ``message`` and type it into the chat.

    ``history`` is expected to end with the pending turn ``[message, None]``
    appended by ``add_message``. That pending turn is excluded from the context
    given to the LLM. Otherwise the current message would appear twice in the
    prompt, next to an assistant turn whose content is None.

    Yields:
        ``history`` with its last assistant slot filled in one character more
        on each step (a typing effect). The full reply is generated before the
        first yield.
    """
    # Everything before the pending turn is completed conversation context.
    previous_turns = history[:-1]
    response = bot_response(message, previous_turns, 0.7, 100)

    history[-1][1] = ""
    for end in range(1, len(response) + 1):
        history[-1][1] = response[:end]
        time.sleep(0.05)
        yield history

def _history_to_conversation(history, message):
    """Convert (user, assistant) history pairs plus a new user message into chat-template messages.

    Turns whose content is None (e.g. an assistant reply not generated yet)
    are skipped instead of being sent to the model.
    """
    conversation = []
    for user, assistant in history:
        if user is not None:
            conversation.append({"role": "user", "content": user})
        if assistant is not None:
            conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})
    return conversation


def bot_response(message: str,
                 history: list,
                 temperature: float,
                 max_new_tokens: int,
                 ) -> str:
    """
    Generate a complete response using the llama3-8b model.

    Generation runs on a background thread and its decoded text is collected
    through a TextIteratorStreamer. The pieces are joined and returned after
    generation finishes.

    Args:
        message (str): The input message.
        history (list): Prior conversation turns as (user, assistant) pairs;
            entries whose content is None are skipped.
        temperature (float): Sampling temperature; 0 selects greedy decoding.
        max_new_tokens (int): The maximum number of new tokens to generate.
    Returns:
        str: The generated response.
    """
    conversation = _history_to_conversation(history, message)

    # add_generation_prompt=True appends the assistant header, so the model
    # starts its answer directly. Without it the model must emit the header
    # itself, and only the header's special tokens are removed by
    # skip_special_tokens, so the role name can leak into the reply.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(llama3_model.device)

    # skip_prompt drops the echoed input. timeout makes iteration raise instead
    # of blocking forever if no new text arrives within 10 s (e.g. when
    # generate() fails on the worker thread).
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
    )
    # Sampling with temperature 0 is invalid, so fall back to greedy decoding.
    if temperature == 0:
        generate_kwargs['do_sample'] = False

    worker = Thread(target=llama3_model.generate, kwargs=generate_kwargs)
    worker.start()
    outputs = list(streamer)
    # The streamer is exhausted once generation ends; join so the worker
    # thread is not left dangling.
    worker.join()
    return "".join(outputs)


with gr.Blocks(
    title="MyAlexa",
    css="\n\t\ttextarea { font-size: 18px;}\n\t",
    theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg),  # make text slightly bigger (default is text_md)
) as demo:

    gr.HTML(DESCRIPTION)
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        placeholder=PLACEHOLDER,
        label='MyAlexa',
    )
    with gr.Row():
        with gr.Column():
            gr.HTML(
                "<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>"
            )

            audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")

        with gr.Column():

            gr.HTML("<p><b>Step 2:</b> Enter audio as input and wait for MyAlexa's response.</p>")

            submit_button = gr.Button(
                value="Submit audio",
                variant="primary",
            )

            chat_input = gr.Textbox(
                label="Transcribed text:",
                interactive=False,
                placeholder="Enter message",
                elem_id="chat_input",
                visible=True,
            )

    # Drive the whole turn from the button click rather than chat_input.change.
    # A change listener does not fire when a new recording transcribes to the
    # same text as the previous one, so that turn would get no reply.
    # .success() continues only if transcription did not raise (e.g. gr.Error
    # for missing or over-long audio), so a stale transcript is never re-sent.
    transcribed = submit_button.click(
        fn=transcribe,
        inputs=[audio_file],
        outputs=[chat_input],
    )
    chat_msg = transcribed.success(add_message, [chatbot, chat_input], [chatbot])
    bot_msg = chat_msg.then(bot, [chatbot, chat_input], chatbot, api_name="bot_response")

# Enable Gradio's request queue before launching.
demo.queue()
if __name__ == "__main__":
    demo.launch()