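"""Local voice chatbot: Whisper ASR -> OPT-1.3b text generation -> gTTS speech synthesis.

Microphone audio is transcribed with Whisper, answered by a small causal LM
(facebook/opt-1.3b), and spoken back with gTTS. A FastRTC Stream is defined for
send-receive audio, and a Gradio demo is provided for local testing (currently
launched unconditionally; see the __main__ block at the bottom of the file).
"""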
from fastrtc import (
    ReplyOnPause, AdditionalOutputs, Stream,
    audio_to_bytes, aggregate_bytes_to_16bit
)
import gradio as gr
import numpy as np
import torch
import os
import tempfile
import time  # needed for temp-file names and the short sleep after WAV conversion
from transformers import (
    AutoModelForSpeechSeq2Seq, 
    AutoProcessor, 
    pipeline,
    AutoTokenizer, 
    AutoModelForCausalLM
)
from gtts import gTTS
from scipy.io import wavfile

# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Step 1: Audio transcription with Whisper
def load_asr_model():
    model_id = "openai/whisper-small"
    
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, 
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)
    
    processor = AutoProcessor.from_pretrained(model_id)
    
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=16,
        return_timestamps=False,
        torch_dtype=torch_dtype,
        device=device,
    )
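
# Example call shape (matches how the ASR pipeline is used in response() below):
#   asr = load_asr_model()
#   result = asr({"sampling_rate": 16000, "raw": float32_mono_audio})
#   text = result["text"]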

# Step 2: Text generation with a smaller LLM
def load_llm_model():
    model_id = "facebook/opt-1.3b"
    
    # First load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    
    # Print current token configuration
    print(f"Initial pad token ID: {tokenizer.pad_token_id}, EOS token ID: {tokenizer.eos_token_id}")
    
    # Load the model first
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True
    )
    
    # Set pad token if needed
    if tokenizer.pad_token is None or tokenizer.pad_token_id == tokenizer.eos_token_id:
        # Add a new special token as padding token
        special_tokens = {'pad_token': '[PAD]'}
        num_added = tokenizer.add_special_tokens(special_tokens)
        
        # Must resize the token embeddings when adding tokens
        model.resize_token_embeddings(len(tokenizer))
        
        # Update the model's config to explicitly set the pad token ID
        model.config.pad_token_id = tokenizer.pad_token_id
        
        print(f"Added pad token: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})")
        print(f"Different from EOS token: '{tokenizer.eos_token}' (ID: {tokenizer.eos_token_id})")
    else:
        print(f"Pad token already set: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})")
        print(f"EOS token: '{tokenizer.eos_token}' (ID: {tokenizer.eos_token_id})")
    
    # Move model to the right device
    model.to(device)
    
    return model, tokenizer

# Step 3: Text-to-Speech with gTTS (Google Text-to-Speech)
def gtts_text_to_speech(text):
    """Convert text to speech using gTTS and ensure proper WAV format."""
    # Create absolute paths for temporary files
    temp_dir = tempfile.gettempdir()
    mp3_filename = os.path.join(temp_dir, f"tts_temp_{os.getpid()}_{time.time()}.mp3")
    wav_filename = os.path.join(temp_dir, f"tts_temp_{os.getpid()}_{time.time()}.wav")
    
    try:
        # Make sure text is not empty
        if not text or text.isspace():
            text = "I don't have a response for that."
        
        # Create gTTS object and save to MP3
        tts = gTTS(text=text, lang='en', slow=False)
        tts.save(mp3_filename)
        
        print(f"MP3 file created: {mp3_filename}, size: {os.path.getsize(mp3_filename)}")
        
        # Try multiple methods to convert MP3 to WAV
        wav_created = False
        
        # Method 1: Try ffmpeg (most reliable)
        try:
            import subprocess
            cmd = ['ffmpeg', '-y', '-i', mp3_filename, '-acodec', 'pcm_s16le', '-ar', '24000', '-ac', '1', wav_filename]
            print(f"Running ffmpeg command: {' '.join(cmd)}")
            
            result = subprocess.run(
                cmd, 
                stdout=subprocess.PIPE, 
                stderr=subprocess.PIPE,
                check=True
            )
            
            if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
                print(f"WAV file successfully created with ffmpeg: {wav_filename}, size: {os.path.getsize(wav_filename)}")
                wav_created = True
            else:
                print(f"ffmpeg ran but WAV file is missing or too small: {wav_filename}")
            
        except Exception as e:
            print(f"ffmpeg conversion failed: {str(e)}")
        
        # Method 2: Try pydub if ffmpeg failed
        if not wav_created:
            try:
                from pydub import AudioSegment
                print("Converting MP3 to WAV using pydub...")
                sound = AudioSegment.from_mp3(mp3_filename)
                sound = sound.set_frame_rate(24000).set_channels(1)
                sound.export(wav_filename, format="wav")
                
                if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
                    print(f"WAV file successfully created with pydub: {wav_filename}, size: {os.path.getsize(wav_filename)}")
                    wav_created = True
                else:
                    print(f"pydub ran but WAV file is missing or too small")
                
            except Exception as e:
                print(f"pydub conversion failed: {str(e)}")
        
        # Method 3: Generate a synthetic placeholder tone directly with NumPy/SciPy (last resort)
        if not wav_created:
            try:
                import numpy as np
                from scipy.io import wavfile
                
                print("Generating synthetic speech directly...")
                # Generate a simple speech-like tone pattern
                sample_rate = 24000
                duration = len(text) * 0.075  # Approx timing
                t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
                
                # Create a speech-like tone with some variation
                frequencies = [220, 440, 330, 550]
                audio = np.zeros_like(t)
                for i, freq in enumerate(frequencies):
                    audio += 0.2 * np.sin(2 * np.pi * freq * t + i)
                
                # Add some envelope
                envelope = np.ones_like(t)
                attack = int(0.01 * sample_rate)
                release = int(0.1 * sample_rate)
                envelope[:attack] = np.linspace(0, 1, attack)
                envelope[-release:] = np.linspace(1, 0, release)
                audio = audio * envelope
                
                # Normalize and convert to int16
                audio = audio / np.max(np.abs(audio))
                audio = (audio * 32767).astype(np.int16)
                
                # Save as WAV
                wavfile.write(wav_filename, sample_rate, audio)
                
                if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
                    print(f"WAV file successfully created directly: {wav_filename}, size: {os.path.getsize(wav_filename)}")
                    wav_created = True
                
            except Exception as e:
                print(f"Direct WAV creation failed: {str(e)}")
        
        # Read the WAV file if it was created
        if wav_created:
            try:
                # Add a small delay to ensure the file is fully written
                time.sleep(0.1)
                
                # Read WAV file with scipy
                print(f"Reading WAV file: {wav_filename}")
                sample_rate, audio_data = wavfile.read(wav_filename)
                
                # Convert to expected format
                audio_data = audio_data.reshape(1, -1).astype(np.int16)
                print(f"WAV file read successfully, shape: {audio_data.shape}, sample rate: {sample_rate}")
                return (sample_rate, audio_data)
                
            except Exception as e:
                print(f"Error reading WAV file: {str(e)}")
        
        # If all else fails, generate a simple tone
        print("All methods failed. Falling back to synthetic audio tone")
        sample_rate = 24000
        duration_sec = max(1, len(text) * 0.1)
        tone_length = int(sample_rate * duration_sec)
        audio_data = np.sin(2 * np.pi * np.arange(tone_length) * 440 / sample_rate)
        audio_data = (audio_data * 32767).astype(np.int16)
        audio_data = audio_data.reshape(1, -1)
        return (sample_rate, audio_data)
        
    except Exception as e:
        print(f"Unexpected error in text-to-speech: {str(e)}")
        # Generate a simple tone as last resort
        sample_rate = 24000
        audio_data = np.sin(2 * np.pi * np.arange(sample_rate) * 440 / sample_rate)
        audio_data = (audio_data * 32767).astype(np.int16)
        audio_data = audio_data.reshape(1, -1)
        return (sample_rate, audio_data)
        
    finally:
        # Clean up temporary files
        for filename in [mp3_filename, wav_filename]:
            try:
                if os.path.exists(filename):
                    os.remove(filename)
            except Exception as e:
                print(f"Failed to remove temporary file {filename}: {str(e)}")

# Initialize models
print("Loading ASR model...")
asr_pipeline = load_asr_model()

print("Loading LLM model...")
llm_model, llm_tokenizer = load_llm_model()

# Chat history management
chat_history = []

def generate_response(prompt):
    # If chat history is empty, add a system message
    if not chat_history:
        chat_history.append({"role": "system", "content": "You are a helpful, friendly AI assistant. Keep your responses concise and conversational."})
    
    # Add user message to history
    chat_history.append({"role": "user", "content": prompt})
    
    # Prepare input for the model
    full_prompt = ""
    for message in chat_history:
        if message["role"] == "system":
            full_prompt += f"System: {message['content']}\n"
        elif message["role"] == "user":
            full_prompt += f"User: {message['content']}\n"
        elif message["role"] == "assistant":
            full_prompt += f"Assistant: {message['content']}\n"
    
    full_prompt += "Assistant: "
    
    # Tokenize with an explicit attention mask. For a single prompt no padding is needed,
    # and right-padding a decoder-only model before generate() is a known pitfall, so we
    # only truncate and let the attention mask cover the real tokens.
    tokenized_inputs = llm_tokenizer(
        full_prompt,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        return_attention_mask=True
    )
    
    # Move to device
    input_ids = tokenized_inputs["input_ids"].to(device)
    attention_mask = tokenized_inputs["attention_mask"].to(device)
    
    # Generate response - explicitly pass all needed parameters
    with torch.no_grad():
        output = llm_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=llm_tokenizer.pad_token_id,  # Explicitly set pad token ID
            eos_token_id=llm_tokenizer.eos_token_id   # Explicitly set EOS token ID
        )
    
    response_text = llm_tokenizer.decode(output[0], skip_special_tokens=True)
    response_text = response_text.split("Assistant: ")[-1].strip()
    
    # Add assistant response to history
    chat_history.append({"role": "assistant", "content": response_text})
    
    # Keep history at a reasonable size: once it grows past 10 entries,
    # drop the oldest non-system message (index 0 is the system prompt)
    if len(chat_history) > 10:
        chat_history.pop(1)
    
    return response_text

def response(audio: tuple[int, np.ndarray]):
    # Step 1: Convert audio to float32 before passing to ASR
    sample_rate, audio_data = audio
    
    # Convert int16 audio to float32
    audio_float32 = audio_data.flatten().astype(np.float32) / 32768.0  # Normalize to [-1.0, 1.0]
    
    # Speech-to-Text with correct data type
    transcript = asr_pipeline({
        "sampling_rate": sample_rate, 
        "raw": audio_float32
    })
    
    prompt = transcript["text"]
    print(f"Transcribed: {prompt}")
    
    # Step 2: Generate text response
    response_text = generate_response(prompt)
    print(f"Response: {response_text}")
    
    # Step 3: Text-to-Speech using gTTS
    sample_rate, audio_array = gtts_text_to_speech(response_text)
    
    # Convert to expected format and yield chunks
    chunk_size = int(sample_rate * 0.2)  # 200ms chunks
    for i in range(0, audio_array.shape[1], chunk_size):
        chunk = audio_array[:, i:i+chunk_size]
        if chunk.size > 0:  # Ensure we don't yield empty chunks
            yield (sample_rate, chunk)

stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response),
)
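
# Assumption: fastrtc's Stream typically exposes a built-in Gradio UI that can be served
# directly (e.g. `stream.ui.launch()`) or mounted on a FastAPI app; this script instead
# falls back to the plain Gradio demo below.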

# For testing without WebRTC
def demo():
    with gr.Blocks() as demo:
        gr.Markdown("# Local Voice Chatbot")
        audio_input = gr.Audio(sources=["microphone"], type="numpy")
        audio_output = gr.Audio()
        
        def process_audio(audio):
            if audio is None:
                return None
            
            sample_rate, audio_array = audio
            
            # Convert to float32 for ASR
            audio_float32 = audio_array.flatten().astype(np.float32) / 32768.0
            
            transcript = asr_pipeline({
                "sampling_rate": sample_rate, 
                "raw": audio_float32
            })
            
            prompt = transcript["text"]
            print(f"Transcribed: {prompt}")
            
            response_text = generate_response(prompt)
            print(f"Response: {response_text}")
            
            sample_rate, audio_array = gtts_text_to_speech(response_text)
            return (sample_rate, audio_array[0])
        
        audio_input.change(process_audio, inputs=[audio_input], outputs=[audio_output])
    
    demo.launch()

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--demo", action="store_true", help="Run Gradio demo instead of WebRTC")
    args = parser.parse_args()
    # Always run the Gradio demo for now (workaround for Hugging Face deployment issues)
    demo()
    
    # if args.demo:
    #     demo()
    # else:
    #     # For running with FastRTC
    #     # You would need to add your FastRTC server code here
    #     pass