VanYsa committed
Commit 930e925 · 1 Parent(s): a6cdcf6

Update app.py

Changed LLM to test

Files changed (1)
  1. app.py +45 -79
app.py CHANGED
@@ -5,16 +5,11 @@ import os
import soundfile as sf
import tempfile
import uuid
- import os
+ import transformers
import torch
import time
- from transformers import GemmaTokenizer, AutoModelForCausalLM
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
- from threading import Thread

from nemo.collections.asr.models import ASRModel
- from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
- from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED

SAMPLE_RATE = 16000 # Hz
MAX_AUDIO_SECONDS = 40 # wont try to transcribe if longer than this
@@ -29,15 +24,14 @@ DESCRIPTION = '''
'''
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
- <img src="MyAlexaLogo.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
+ <img src="https://huggingface.co/spaces/VanYsa/MyAlexa/blob/main/MyAlexaLogo.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
<p style="font-size: 28px; margin-bottom: 2px; opacity: 0.65;">What's on your mind?</p>
</div>
"""

- # Set an environment variable
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ ### ASR model
canary_model = ASRModel.from_pretrained("nvidia/canary-1b").to(device)
canary_model.eval()

@@ -47,29 +41,14 @@ decoding_cfg = canary_model.cfg.decoding
decoding_cfg.beam.beam_size = 1
canary_model.change_decoding_strategy(decoding_cfg)

- # setup for buffered inference
- canary_model.cfg.preprocessor.dither = 0.0
- canary_model.cfg.preprocessor.pad_to = 0
-
- feature_stride = canary_model.cfg.preprocessor['window_stride']
- model_stride_in_secs = feature_stride * 8 # 8 = model stride, which is 8 for FastConformer
-
- frame_asr = FrameBatchMultiTaskAED(
-     asr_model=canary_model,
-     frame_len=40.0,
-     total_buffer=40.0,
-     batch_size=16,
+ ### LLM model
+ pipeline = transformers.pipeline(
+     "text-generation",
+     model="meta-llama/Meta-Llama-3-8B-Instruct",
+     model_kwargs={"torch_dtype": torch.bfloat16},
+     device=device
)

- amp_dtype = torch.float16
-
- tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
- llama3_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct").to(device)
- terminators = [
-     tokenizer.eos_token_id,
-     tokenizer.convert_tokens_to_ids("<|eot_id|>")
- ]
-
def convert_audio(audio_filepath, tmpdir, utt_id):
    """
    Convert all files to monochannel 16 kHz wav files.
@@ -142,62 +121,49 @@ def add_message(history, message):
    history.append((message, None))
    return history, gr.Textbox(value="", interactive=False)

- def bot(history, message):
+ def bot(history,message):
    """
    Prints the LLM's response in the chatbot
    """
-     response = chat_llama3_8b(message, history, 0.95, 512)
+     response = bot_response(history, message)
    history[-1][1] = ""
    for character in response:
        history[-1][1] += character
        time.sleep(0.05)
        yield history

- def chat_llama3_8b(message: str,
-                    history: list,
-                    temperature: float,
-                    max_new_tokens: int
-                    ) -> str: # type: ignore
-     """
-     Generate a streaming response using the llama3-8b model.
-     Args:
-         message (str): The input message.
-         history (list): The conversation history used by ChatInterface.
-         temperature (float): The temperature for generating the response.
-         max_new_tokens (int): The maximum number of new tokens to generate.
-     Returns:
-         str: The generated response.
-     """
-     conversation = []
-     for user, assistant in history:
-         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-     conversation.append({"role": "user", "content": message})
-
-     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(device)
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-
-     generate_kwargs = dict(
-         input_ids= input_ids,
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         temperature=temperature,
-         eos_token_id=terminators,
-     )
-     # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
-     if temperature == 0:
-         generate_kwargs['do_sample'] = False
-
-     t = Thread(target=llama3_model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         #print(outputs)
-         yield "".join(outputs)
-
+ def bot_response(history, message):
+     """
+     Generates a response from the LLM model.
+     Temperature and top_p are set to 0.6 and 0.9 respectively.
+     """
+     messages = [
+         {"role": "system", "content": "You are a helpful AI assistant."},
+         {"role": "user", "content": message},
+     ]
+
+     prompt = pipeline.tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True
+     )
+
+     terminators = [
+         pipeline.tokenizer.eos_token_id,
+         pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
+     ]
+
+     outputs = pipeline(
+         prompt,
+         max_new_tokens=512,
+         eos_token_id=terminators,
+         do_sample=True,
+         temperature=0.6,
+         top_p=0.9,
+     )
+     print(outputs[0]["generated_text"][len(prompt):])
+     return outputs[0]["generated_text"][len(prompt):]
+
with gr.Blocks(
    title="MyAlexa",
    css="""
@@ -255,7 +221,7 @@ with gr.Blocks(
    )

    chat_msg = chat_input.change(add_message, [chatbot, chat_input], [chatbot, chat_input])
-     bot_msg = chat_msg.then(bot, [chatbot, chat_msg], chatbot, api_name="bot_response")
+     bot_msg = chat_msg.then(bot, [chatbot, chat_input], chatbot, api_name="bot_response")
    bot_msg.then(lambda: gr.Textbox(interactive=True), None, [chat_input])

    submit_button.click(
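
For reference, this commit swaps the hand-rolled AutoModelForCausalLM + TextIteratorStreamer setup for a single transformers text-generation pipeline. Below is a minimal standalone sketch of the new generation path; the generate_reply wrapper and the sample prompt are illustrative only and are not part of app.py, and the sketch assumes access to the gated meta-llama/Meta-Llama-3-8B-Instruct checkpoint (e.g. after huggingface-cli login) plus enough memory for bfloat16 weights.

# Minimal sketch of the pipeline-based generation path introduced in this commit.
# Assumption: the gated meta-llama/Meta-Llama-3-8B-Instruct checkpoint is accessible.
import torch
import transformers

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pipeline = transformers.pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device=device,
)

def generate_reply(message: str) -> str:
    # Build a Llama-3 chat prompt for a single user turn, mirroring bot_response().
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": message},
    ]
    prompt = pipeline.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Stop on either the standard EOS token or Llama-3's <|eot_id|> turn delimiter.
    terminators = [
        pipeline.tokenizer.eos_token_id,
        pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    outputs = pipeline(
        prompt,
        max_new_tokens=512,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # The pipeline returns the prompt plus the completion; keep only the completion.
    return outputs[0]["generated_text"][len(prompt):]

print(generate_reply("What's on your mind?"))

Unlike the removed streamer-based chat_llama3_8b, the pipeline call returns the whole completion at once; the app's bot() then replays it character by character with a short sleep to keep the typing effect in the chatbot.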