hashhac committed
Commit 6218f6a · 1 Parent(s): fdd081d

try 2 orion time

Files changed (2)
  1. app.py +43 -220
  2. requirements.txt +8 -8
app.py CHANGED
@@ -1,197 +1,35 @@
  import gradio as gr
  import numpy as np
  import torch
- import os
- import tempfile
- from transformers import (
-     AutoModelForSpeechSeq2Seq,
-     AutoProcessor,
-     pipeline,
-     AutoTokenizer,
-     AutoModelForCausalLM
- )

  # Check if CUDA is available, otherwise use CPU
  device = "cuda" if torch.cuda.is_available() else "cpu"
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

- # Initialize pyttsx3 for local TTS
- def load_local_tts():
-     import pyttsx3
-
-     engine = pyttsx3.init()
-     engine.setProperty('rate', 150) # Speed of speech
-     engine.setProperty('volume', 0.9) # Volume
-
-     voices = engine.getProperty('voices')
-     if len(voices) > 1:
-         engine.setProperty('voice', voices[1].id) # Set female voice
-
-     return engine
-
- # Initialize the TTS engine
- print("Loading local TTS engine...")
- tts_engine = load_local_tts()
-
- def text_to_speech_local(text):
-     """Convert text to speech using pyttsx3 local TTS engine"""
-     import tempfile
-     import soundfile as sf
-
-     # Create a temporary file to store the audio
-     with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
-         temp_filename = temp_file.name
-
-     # Generate speech to the temporary file
-     tts_engine.save_to_file(text, temp_filename)
-     tts_engine.runAndWait()
-
-     # Read the audio file
-     audio_data, sample_rate = sf.read(temp_filename)
-
-     # Convert to the expected format
-     if len(audio_data.shape) == 1:
-         audio_data = audio_data.reshape(1, -1)
-     else:
-         audio_data = audio_data[:, 0].reshape(1, -1)
-
-     # Ensure it's int16
-     audio_data = (audio_data * 32767).astype(np.int16)
-
-     # Clean up the temporary file
-     os.unlink(temp_filename)
-
-     return (sample_rate, audio_data)
-
- # Load ASR model (Whisper)
- def load_asr_model():
-     model_id = "openai/whisper-small"
-
-     model = AutoModelForSpeechSeq2Seq.from_pretrained(
-         model_id,
-         torch_dtype=torch_dtype,
-         low_cpu_mem_usage=True,
-         use_safetensors=True
-     )
-     model.to(device)
-
-     processor = AutoProcessor.from_pretrained(model_id)
-
-     return pipeline(
-         "automatic-speech-recognition",
-         model=model,
-         tokenizer=processor.tokenizer,
-         feature_extractor=processor.feature_extractor,
-         max_new_tokens=128,
-         chunk_length_s=30,
-         batch_size=16,
-         return_timestamps=False,
-         torch_dtype=torch_dtype,
-         device=device,
-     )
-
- # Load LLM model
- def load_llm_model():
-     model_id = "facebook/opt-1.3b"
-
-     tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-     if tokenizer.pad_token is None:
-         tokenizer.add_special_tokens({'pad_token': '[PAD]'})
-
-     model = AutoModelForCausalLM.from_pretrained(
-         model_id,
-         torch_dtype=torch_dtype,
-         low_cpu_mem_usage=True
-     )
-
-     model.resize_token_embeddings(len(tokenizer))
-     model.config.pad_token_id = tokenizer.pad_token_id
-
-     if hasattr(model.config, "word_embed_proj_dim"):
-         model.config._remove_wrong_keys = False
-
-     model.to(device)
-
-     return model, tokenizer
-
- # Initialize models
- print("Loading ASR model...")
- asr_pipeline = load_asr_model()

- print("Loading LLM model...")
- llm_model, llm_tokenizer = load_llm_model()
-
- # Chat history management
- chat_history = []

- def generate_response(prompt):
-     # If chat history is empty, add a system message
-     if not chat_history:
-         chat_history.append({"role": "system", "content": "You are a helpful, friendly AI assistant. Keep your responses concise and conversational."})
-
-     # Add user message to history
-     chat_history.append({"role": "user", "content": prompt})
-
-     # Build prompt from chat history
-     full_prompt = ""
-     for message in chat_history:
-         if message["role"] == "system":
-             full_prompt += f"System: {message['content']}\n"
-         elif message["role"] == "user":
-             full_prompt += f"User: {message['content']}\n"
-         elif message["role"] == "assistant":
-             full_prompt += f"Assistant: {message['content']}\n"
-
-     full_prompt += "Assistant: "
-
-     # Encode input
-     encoded_input = llm_tokenizer.encode_plus(
-         full_prompt,
-         return_tensors="pt",
-         padding=False,
-         add_special_tokens=True,
-         return_attention_mask=True
-     )
-
-     input_ids = encoded_input["input_ids"].to(device)
-     attention_mask = torch.ones_like(input_ids).to(device)
-
-     # Generate response
      with torch.no_grad():
-         try:
-             output = llm_model.generate(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 max_new_tokens=128,
-                 do_sample=True,
-                 temperature=0.7,
-                 top_p=0.9,
-                 pad_token_id=llm_tokenizer.pad_token_id,
-                 eos_token_id=llm_tokenizer.eos_token_id,
-                 use_cache=True
-             )
-         except Exception as e:
-             output = llm_model.generate(
-                 input_ids=input_ids,
-                 max_new_tokens=128,
-                 do_sample=True,
-                 temperature=0.7
-             )
-
-     # Decode output
-     response_text = llm_tokenizer.decode(output[0], skip_special_tokens=True)
-     response_text = response_text.split("Assistant: ")[-1].strip()
-
-     # Add assistant response to history
-     chat_history.append({"role": "assistant", "content": response_text})
-
-     # Keep history manageable
-     if len(chat_history) > 10:
-         chat_history.pop(1)
-
-     return response_text

  def demo():
      with gr.Blocks() as demo:
          gr.Markdown("# Voice Chatbot")
@@ -205,55 +43,40 @@ def demo():
              if audio is None:
                  return None, "No audio detected."

-             # Track conversation for display
-             conversation_text = ""
-
-             # Process audio
-             sample_rate, audio_array = audio
-
-             # Convert to float32 for ASR
-             audio_float32 = audio_array.flatten().astype(np.float32) / 32768.0

              # Speech-to-text
-             transcript = asr_pipeline({
-                 "sampling_rate": sample_rate,
-                 "raw": audio_float32
-             })

-             prompt = transcript["text"]
-             conversation_text += f"You: {prompt}\n"
-             print(f"Transcribed: {prompt}")
-
-             # Generate response
-             response_text = generate_response(prompt)
-             conversation_text += f"Assistant: {response_text}\n"
              print(f"Response: {response_text}")

-             # Convert to speech
-             sample_rate, audio_array = text_to_speech_local(response_text)

-             # Concatenate chunks for Gradio
-             full_audio = np.concatenate([audio_array[:, i:i+int(sample_rate*0.2)]
-                                          for i in range(0, audio_array.shape[1], int(sample_rate*0.2))
-                                          if audio_array[:, i:i+int(sample_rate*0.2)].size > 0], axis=1)

-             return (sample_rate, full_audio), conversation_text

          audio_input.change(process_audio,
-                            inputs=[audio_input],
-                            outputs=[audio_output, transcript_display])

          clear_btn = gr.Button("Clear Conversation")
          clear_btn.click(lambda: (None, ""), outputs=[audio_output, transcript_display])
-
-         # Add function to clear chat history
-         def reset_chat():
-             global chat_history
-             chat_history = []
-             return None, "Conversation history cleared."
-
-         reset_btn = gr.Button("Reset Chat History")
-         reset_btn.click(reset_chat, outputs=[audio_output, transcript_display])

      demo.launch()
 
  import gradio as gr
  import numpy as np
  import torch
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5ForSpeechToText
+ from datasets import load_dataset
+ import soundfile as sf

  # Check if CUDA is available, otherwise use CPU
  device = "cuda" if torch.cuda.is_available() else "cpu"

+ # Load SpeechT5 models and processor
+ processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_asr")
+ asr_model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr").to(device)
+ tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)

+ # Function to convert speech to text
+ def speech_to_text(audio):
+     inputs = processor(audio, sampling_rate=16000, return_tensors="pt").input_values.to(device)
+     with torch.no_grad():
+         logits = asr_model(inputs).logits
+     predicted_ids = torch.argmax(logits, dim=-1)
+     transcription = processor.batch_decode(predicted_ids)[0]
+     return transcription
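Note: SpeechT5ForSpeechToText is an encoder-decoder model, so taking an argmax over the logits of a single forward pass (as above) will generally not decode into usable text; the documented path is autoregressive generation. A minimal sketch, not part of this commit (the helper name transcribe is illustrative, and the input is assumed to be 16 kHz mono float32 audio):

    import torch
    from transformers import SpeechT5Processor, SpeechT5ForSpeechToText

    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_asr")
    asr_model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr")

    def transcribe(audio_16k):
        # Feature-extract the raw 16 kHz waveform
        inputs = processor(audio=audio_16k, sampling_rate=16000, return_tensors="pt")
        # Autoregressive decoding instead of argmax over encoder logits
        with torch.no_grad():
            predicted_ids = asr_model.generate(**inputs, max_length=100)
        return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]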
 
+ # Function to convert text to speech
+ def text_to_speech(text):
+     inputs = processor(text, return_tensors="pt").input_ids.to(device)
      with torch.no_grad():
+         speech = tts_model.generate_speech(inputs)
+     return speech
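Note: generate_speech normally also takes speaker embeddings and, optionally, a vocoder, which is presumably what the otherwise unused load_dataset import is for; the TTS checkpoint also ships its own processor. A minimal sketch of the usual pattern, not part of this commit, assuming the Matthijs/cmu-arctic-xvectors speaker-embedding dataset and the microsoft/speecht5_hifigan vocoder (the helper name synthesize is illustrative):

    import torch
    from datasets import load_dataset
    from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan

    tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

    # x-vector speaker embeddings; index 7306 is a commonly used example voice
    embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    speaker_embeddings = torch.tensor(embeddings[7306]["xvector"]).unsqueeze(0)

    def synthesize(text):
        input_ids = tts_processor(text=text, return_tensors="pt").input_ids
        with torch.no_grad():
            # Returns a 1-D waveform tensor at 16 kHz
            return tts_model.generate_speech(input_ids, speaker_embeddings, vocoder=vocoder)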
 
+ # Gradio demo
  def demo():
      with gr.Blocks() as demo:
          gr.Markdown("# Voice Chatbot")

              if audio is None:
                  return None, "No audio detected."

+             # Convert audio to the correct format
+             sample_rate, audio_data = audio
+             audio_data = audio_data.flatten().astype(np.float32) / 32768.0 # Normalize to [-1.0, 1.0]

              # Speech-to-text
+             transcript = speech_to_text(audio_data)
+             print(f"Transcribed: {transcript}")

+             # Generate response (for simplicity, echo the transcript)
+             response_text = transcript
              print(f"Response: {response_text}")

+             # Text-to-speech
+             response_audio = text_to_speech(response_text)
+
+             # Save the response audio to a temporary file
+             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
+                 sf.write(temp_file.name, response_audio.cpu().numpy(), 16000)
+                 temp_filename = temp_file.name
+
+             # Read the audio file
+             audio_data, sample_rate = sf.read(temp_filename)

+             # Clean up the temporary file
+             os.unlink(temp_filename)

+             return (sample_rate, audio_data), f"You: {transcript}\nAssistant: {response_text}"

          audio_input.change(process_audio,
+                            inputs=[audio_input],
+                            outputs=[audio_output, transcript_display])

          clear_btn = gr.Button("Clear Conversation")
          clear_btn.click(lambda: (None, ""), outputs=[audio_output, transcript_display])

      demo.launch()
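Two caveats on the new process_audio, noted here rather than edited into the commit: it still calls tempfile and os, which the new import block no longer brings in, and SpeechT5 expects 16 kHz mono input while Gradio microphone audio typically arrives at 44.1 or 48 kHz, so normalizing without resampling feeds the ASR model audio at the wrong rate. A sketch of the missing resampling step using scipy, which is already in requirements.txt (the helper name to_16k is illustrative):

    # Sketch only: convert Gradio microphone audio to 16 kHz mono float32 for SpeechT5.
    import numpy as np
    from math import gcd
    from scipy.signal import resample_poly

    def to_16k(sample_rate, audio_int16, target_rate=16000):
        audio = audio_int16.astype(np.float32) / 32768.0 # int16 -> [-1.0, 1.0]
        if audio.ndim > 1:
            audio = audio.mean(axis=1) # downmix stereo to mono
        if sample_rate != target_rate:
            g = gcd(sample_rate, target_rate)
            audio = resample_poly(audio, target_rate // g, sample_rate // g)
        return audio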
 
requirements.txt CHANGED
@@ -1,16 +1,16 @@
- transformers
- torch
- datasets
- scipy
- fastrtc
  gradio
  accelerate
- sentencepiece
  fastrtc[vad,tts]
  torchaudio
  gtts
- pydub
  scipy
- pyttsx3
  soundfile
  py-espeak-ng

+ transformers
+ torch
+ datasets
+ scipy
+ fastrtc
  gradio
  accelerate
+ sentencepiece
  fastrtc[vad,tts]
  torchaudio
  gtts
+ pydub
  scipy
+ pyttsx3
  soundfile
  py-espeak-ng