Update app.py
Changed LLM to test
app.py
CHANGED
@@ -5,16 +5,11 @@ import os
 import soundfile as sf
 import tempfile
 import uuid
-import
+import transformers
 import torch
 import time
-from transformers import GemmaTokenizer, AutoModelForCausalLM
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-from threading import Thread
 
 from nemo.collections.asr.models import ASRModel
-from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
-from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
 
 SAMPLE_RATE = 16000 # Hz
 MAX_AUDIO_SECONDS = 40 # wont try to transcribe if longer than this
@@ -29,15 +24,14 @@ DESCRIPTION = '''
 '''
 PLACEHOLDER = """
 <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
-    <img src="MyAlexaLogo.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
+    <img src="https://huggingface.co/spaces/VanYsa/MyAlexa/blob/main/MyAlexaLogo.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
     <p style="font-size: 28px; margin-bottom: 2px; opacity: 0.65;">What's on your mind?</p>
 </div>
 """
 
-# Set an environment variable
-HF_TOKEN = os.environ.get("HF_TOKEN", None)
-
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+### ASR model
 canary_model = ASRModel.from_pretrained("nvidia/canary-1b").to(device)
 canary_model.eval()
 
@@ -47,29 +41,14 @@ decoding_cfg = canary_model.cfg.decoding
 decoding_cfg.beam.beam_size = 1
 canary_model.change_decoding_strategy(decoding_cfg)
 
-
-
-
-
-
-
-
-frame_asr = FrameBatchMultiTaskAED(
-    asr_model=canary_model,
-    frame_len=40.0,
-    total_buffer=40.0,
-    batch_size=16,
+### LLM model
+pipeline = transformers.pipeline(
+    "text-generation",
+    model="meta-llama/Meta-Llama-3-8B-Instruct",
+    model_kwargs={"torch_dtype": torch.bfloat16},
+    device=device
 )
 
-amp_dtype = torch.float16
-
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
-llama3_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct").to(device)
-terminators = [
-    tokenizer.eos_token_id,
-    tokenizer.convert_tokens_to_ids("<|eot_id|>")
-]
-
 def convert_audio(audio_filepath, tmpdir, utt_id):
     """
     Convert all files to monochannel 16 kHz wav files.
@@ -142,62 +121,49 @@ def add_message(history, message):
     history.append((message, None))
     return history, gr.Textbox(value="", interactive=False)
 
-def bot(history,
+def bot(history,message):
     """
     Prints the LLM's response in the chatbot
     """
-    response =
+    response = bot_response(history, message)
     history[-1][1] = ""
     for character in response:
         history[-1][1] += character
         time.sleep(0.05)
         yield history
 
-def
-    … [31 removed lines not shown]
-    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
-    if temperature == 0:
-        generate_kwargs['do_sample'] = False
-
-    t = Thread(target=llama3_model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        #print(outputs)
-        yield "".join(outputs)
-
+def bot_response(history, message):
+    """
+    Generates a response from the LLM model.
+    Temperature and top_p are set to 0.6 and 0.9 respectively.
+    """
+    messages = [
+        {"role": "system", "content": "You are a helpful AI assistant."},
+        {"role": "user", "content": message},
+    ]
+
+    prompt = pipeline.tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+
+    terminators = [
+        pipeline.tokenizer.eos_token_id,
+        pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
+    ]
+
+    outputs = pipeline(
+        prompt,
+        max_new_tokens=512,
+        eos_token_id=terminators,
+        do_sample=True,
+        temperature=0.6,
+        top_p=0.9,
+    )
+    print(outputs[0]["generated_text"][len(prompt):])
+    return outputs[0]["generated_text"][len(prompt):]
+
 with gr.Blocks(
     title="MyAlexa",
     css="""
@@ -255,7 +221,7 @@ with gr.Blocks(
     )
 
     chat_msg = chat_input.change(add_message, [chatbot, chat_input], [chatbot, chat_input])
-    bot_msg = chat_msg.then(bot, [chatbot,
+    bot_msg = chat_msg.then(bot, [chatbot, chat_input], chatbot, api_name="bot_response")
    bot_msg.then(lambda: gr.Textbox(interactive=True), None, [chat_input])
 
     submit_button.click(
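
Note on the change: the old path streamed tokens from llama3_model.generate via TextIteratorStreamer on a background thread, while the new path makes one blocking transformers pipeline call and bot() then replays the finished reply one character at a time (time.sleep(0.05)) to produce the typing effect in the chatbot. For quick verification outside Gradio, the sketch below exercises the new generation path on its own. It reuses the model and sampling settings from the diff; the example user message is purely illustrative, and it assumes access to the gated meta-llama/Meta-Llama-3-8B-Instruct checkpoint plus a GPU for bfloat16.

import torch
import transformers

# Same pipeline configuration that this commit adds to app.py.
pipeline = transformers.pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

# Illustrative input; in app.py this comes from the Gradio textbox.
message = "Summarize what the Canary ASR model in this app is used for."
messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": message},
]

# Render the chat template to a prompt string, then sample a completion.
prompt = pipeline.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
outputs = pipeline(
    prompt,
    max_new_tokens=512,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)

# The pipeline returns prompt + completion; slice off the prompt, as bot_response() does.
print(outputs[0]["generated_text"][len(prompt):])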
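
Because the .then() hookup now sets api_name="bot_response", the chat handler is also exposed as a named endpoint on the Space's API. Below is a hedged sketch of calling it with gradio_client: the Space id VanYsa/MyAlexa is inferred from the logo URL above, and the [user, bot]-pair history format mirrors how add_message() builds the Chatbot value; adjust both if the deployment differs.

from gradio_client import Client

# Space id inferred from the logo URL in this diff; change it if the Space lives elsewhere.
client = Client("VanYsa/MyAlexa")

# bot(history, message) takes the Chatbot history (list of [user, bot] pairs) and the
# textbox value; predict() returns the final history yielded by the generator.
history = [["Hello there!", None]]
result = client.predict(
    history,
    "Hello there!",
    api_name="/bot_response",
)
print(result)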