File size: 15,797 Bytes
e724e7e
 
 
a958ea7
e724e7e
c07dd66
e724e7e
7afe31a
 
289ad8b
7afe31a
 
 
 
 
289ad8b
7afe31a
289ad8b
 
7afe31a
 
 
 
 
 
 
289ad8b
7afe31a
 
 
 
 
 
e724e7e
7afe31a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e724e7e
7afe31a
 
 
289ad8b
7afe31a
190ab02
7afe31a
7dc0ac9
5c42f52
dbf60e3
de7876c
190ab02
 
 
 
 
 
 
 
 
dbf60e3
 
 
 
 
 
190ab02
 
5c42f52
190ab02
 
7dc0ac9
190ab02
 
 
 
 
7afe31a
 
190ab02
 
 
7afe31a
 
289ad8b
 
de7876c
8e6480a
 
 
 
dbf60e3
 
 
 
289ad8b
 
dbf60e3
 
 
 
 
de7876c
 
 
dbf60e3
 
 
 
 
 
de7876c
 
dbf60e3
 
 
de7876c
dbf60e3
 
 
de7876c
 
 
dbf60e3
 
 
 
 
 
 
 
de7876c
dbf60e3
 
de7876c
dbf60e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e6480a
dbf60e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de7876c
dbf60e3
de7876c
 
dbf60e3
de7876c
dbf60e3
de7876c
dbf60e3
de7876c
 
dbf60e3
de7876c
dbf60e3
de7876c
 
 
 
 
 
dbf60e3
 
 
 
 
 
 
 
 
de7876c
 
 
 
 
 
dbf60e3
 
519f37a
289ad8b
7afe31a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190ab02
7afe31a
 
 
 
 
 
 
 
 
 
 
190ab02
 
 
 
 
 
 
 
 
 
 
5c42f52
190ab02
 
7dc0ac9
190ab02
 
bd4a44f
190ab02
7afe31a
190ab02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7afe31a
190ab02
7afe31a
 
 
190ab02
7afe31a
 
190ab02
7afe31a
 
 
 
 
 
 
289ad8b
 
 
 
 
 
 
 
 
 
 
 
7afe31a
289ad8b
7afe31a
 
 
289ad8b
7afe31a
289ad8b
 
7afe31a
289ad8b
 
7afe31a
 
 
 
e724e7e
 
 
 
 
7afe31a
 
 
 
 
 
 
 
 
 
 
 
 
 
289ad8b
 
 
 
 
 
 
 
 
7afe31a
 
 
 
 
 
289ad8b
7afe31a
 
 
 
 
 
 
 
 
 
 
289ad8b
8d9d85d
289ad8b
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
from fastrtc import (
    ReplyOnPause, AdditionalOutputs, Stream,
    audio_to_bytes, aggregate_bytes_to_16bit
)
import gradio as gr
import time
import numpy as np
import torch
import os
import tempfile
from transformers import (
    AutoModelForSpeechSeq2Seq, 
    AutoProcessor, 
    pipeline,
    AutoTokenizer, 
    AutoModelForCausalLM
)
from gtts import gTTS
from scipy.io import wavfile

# Select the compute target once: half precision on a CUDA GPU when one is
# available, full precision on the CPU otherwise.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Step 1: Audio transcription with Whisper
def load_asr_model():
    """Build the Whisper speech-recognition pipeline.

    Loads the ``openai/whisper-small`` checkpoint (safetensors weights) in the
    module-level ``torch_dtype``, moves it to ``device`` and wraps it together
    with its processor's tokenizer and feature extractor in a Hugging Face
    ``automatic-speech-recognition`` pipeline.

    Returns:
        The configured pipeline: at most 128 new tokens per decode, 30 s
        chunking, batch size 16, no timestamps.
    """
    checkpoint = "openai/whisper-small"

    whisper = AutoModelForSpeechSeq2Seq.from_pretrained(
        checkpoint,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    whisper.to(device)

    whisper_processor = AutoProcessor.from_pretrained(checkpoint)

    pipeline_options = dict(
        model=whisper,
        tokenizer=whisper_processor.tokenizer,
        feature_extractor=whisper_processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=16,
        return_timestamps=False,
        torch_dtype=torch_dtype,
        device=device,
    )
    return pipeline("automatic-speech-recognition", **pipeline_options)

# Step 2: Text generation with a smaller LLM
def load_llm_model():
    """Load the OPT causal language model and its tokenizer.

    If the tokenizer has no pad token, a ``[PAD]`` special token is added and
    verified to differ from EOS. The model's embedding table is grown only if
    the tokenizer now has more ids than it can index. The pad id is recorded
    on the model config, and the model is moved to ``device``.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        RuntimeError: if, after adding ``[PAD]``, the pad id still equals the
            EOS id.
    """
    model_id = "facebook/opt-1.3b"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    print(f"Initial pad token ID: {tokenizer.pad_token_id}, EOS token ID: {tokenizer.eos_token_id}")

    # Give the tokenizer a pad token distinct from EOS if it has none. This is
    # done before loading the model so the embedding size check below sees it.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # A real check instead of `assert`, which is stripped under `python -O`.
        if tokenizer.pad_token_id == tokenizer.eos_token_id:
            raise RuntimeError("Pad token still same as EOS token!")
        print(f"Added special PAD token with ID {tokenizer.pad_token_id} (different from EOS: {tokenizer.eos_token_id})")

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )

    # Only grow the embedding table when the tokenizer gained ids the model
    # cannot index. Resizing unconditionally to len(tokenizer) would shrink a
    # checkpoint whose table is already larger than the tokenizer vocabulary.
    vocab_size = len(tokenizer)
    if vocab_size > model.get_input_embeddings().num_embeddings:
        model.resize_token_embeddings(vocab_size)

    # Keep the model config's pad id in sync with the tokenizer.
    model.config.pad_token_id = tokenizer.pad_token_id

    model.to(device)

    print(f"Final token setup - Pad token: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})")
    print(f"Model config pad_token_id: {model.config.pad_token_id}")

    return model, tokenizer

# Step 3: Text-to-Speech with gTTS (Google Text-to-Speech)
_TTS_SAMPLE_RATE = 24000  # sample rate requested from every WAV-producing path
_MIN_WAV_BYTES = 100  # WAV files this small or smaller are treated as failed conversions


def _wav_looks_valid(path):
    """Return True if *path* exists and is larger than ``_MIN_WAV_BYTES``."""
    return os.path.exists(path) and os.path.getsize(path) > _MIN_WAV_BYTES


def _mp3_to_wav_ffmpeg(mp3_path, wav_path):
    """Convert *mp3_path* to mono 16-bit PCM WAV using the ffmpeg CLI.

    Returns True only if ffmpeg exits successfully and the output looks valid.
    """
    import subprocess

    cmd = [
        'ffmpeg', '-y', '-i', mp3_path, '-acodec', 'pcm_s16le',
        '-ar', str(_TTS_SAMPLE_RATE), '-ac', '1', wav_path,
    ]
    print(f"Running ffmpeg command: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
    except (OSError, subprocess.SubprocessError) as e:
        # OSError: ffmpeg binary not found; CalledProcessError: non-zero exit.
        print(f"ffmpeg conversion failed: {str(e)}")
        return False

    if _wav_looks_valid(wav_path):
        print(f"WAV file successfully created with ffmpeg: {wav_path}, size: {os.path.getsize(wav_path)}")
        return True
    print(f"ffmpeg ran but WAV file is missing or too small: {wav_path}")
    return False


def _mp3_to_wav_pydub(mp3_path, wav_path):
    """Convert *mp3_path* to mono WAV at ``_TTS_SAMPLE_RATE`` using pydub.

    pydub is imported lazily because it is only an optional fallback.
    Returns True only if the output looks valid afterwards.
    """
    try:
        from pydub import AudioSegment

        print("Converting MP3 to WAV using pydub...")
        sound = AudioSegment.from_mp3(mp3_path)
        sound = sound.set_frame_rate(_TTS_SAMPLE_RATE).set_channels(1)
        exported = sound.export(wav_path, format="wav")
        # NOTE(review): pydub's export() appears to hand back the file object
        # it wrote to; close it so the handle is not left open until GC.
        if hasattr(exported, "close"):
            exported.close()
    except Exception as e:  # ImportError, decode errors, missing decoder backend, ...
        print(f"pydub conversion failed: {str(e)}")
        return False

    if _wav_looks_valid(wav_path):
        print(f"WAV file successfully created with pydub: {wav_path}, size: {os.path.getsize(wav_path)}")
        return True
    print("pydub ran but WAV file is missing or too small")
    return False


def _synthetic_placeholder(text, sample_rate=_TTS_SAMPLE_RATE):
    """Return a non-speech four-tone chord lasting ~75 ms per character of *text*.

    Used when no MP3->WAV converter worked. Returns ``(sample_rate, audio)``
    with ``audio`` as int16 of shape ``(1, n)``, or None if no samples can be
    produced.
    """
    print("Generating synthetic speech directly...")
    duration = len(text) * 0.075
    n_samples = int(sample_rate * duration)
    if n_samples == 0:
        return None

    t = np.linspace(0, duration, n_samples, endpoint=False)
    audio = np.zeros_like(t)
    for phase, freq in enumerate((220, 440, 330, 550)):
        audio += 0.2 * np.sin(2 * np.pi * freq * t + phase)

    # Fade in and out. Ramp lengths are clamped so very short texts do not
    # overrun the buffer (previously a shape-mismatch error).
    envelope = np.ones_like(t)
    attack = min(int(0.01 * sample_rate), n_samples)
    release = min(int(0.1 * sample_rate), n_samples)
    envelope[:attack] = np.linspace(0, 1, attack)
    envelope[n_samples - release:] = np.linspace(1, 0, release)
    audio *= envelope

    peak = np.max(np.abs(audio))
    if peak == 0:
        return None
    audio = (audio / peak * 32767).astype(np.int16).reshape(1, -1)
    print(f"Synthetic audio generated, shape: {audio.shape}")
    return sample_rate, audio


def _fallback_tone(duration_sec, sample_rate=_TTS_SAMPLE_RATE):
    """Return a full-scale 440 Hz sine of *duration_sec* seconds as ``(rate, (1, n) int16)``."""
    n_samples = int(sample_rate * duration_sec)
    wave = np.sin(2 * np.pi * np.arange(n_samples) * 440 / sample_rate)
    return sample_rate, (wave * 32767).astype(np.int16).reshape(1, -1)


def _remove_temp_file(path):
    """Delete *path* if it exists, logging (not raising) any other failure."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        print(f"Failed to remove temporary file {path}: {str(e)}")


def gtts_text_to_speech(text):
    """Convert *text* to speech with gTTS and return it as int16 PCM.

    The MP3 produced by gTTS is converted to mono WAV at 24 kHz with ffmpeg,
    falling back to pydub. If both converters fail, a non-speech synthetic
    placeholder (tone chord) is returned; as a last resort, a plain 440 Hz
    tone. Temporary files are always removed.

    Args:
        text: Text to speak. Empty or whitespace-only text is replaced with a
            stock "no response" sentence.

    Returns:
        tuple: ``(sample_rate, audio)`` where ``audio`` is ``np.int16`` with
        shape ``(1, n_samples)``.
    """
    import uuid

    # One unique stem per call, so concurrent calls never share temp files.
    stem = os.path.join(tempfile.gettempdir(), f"tts_temp_{os.getpid()}_{uuid.uuid4().hex}")
    mp3_filename = stem + ".mp3"
    wav_filename = stem + ".wav"

    try:
        if not text or text.isspace():
            text = "I don't have a response for that."

        gTTS(text=text, lang='en', slow=False).save(mp3_filename)
        print(f"MP3 file created: {mp3_filename}, size: {os.path.getsize(mp3_filename)}")

        # `or` short-circuits: pydub is only tried when ffmpeg did not succeed.
        # Both converters run synchronously, so no settling delay is needed.
        if _mp3_to_wav_ffmpeg(mp3_filename, wav_filename) or _mp3_to_wav_pydub(mp3_filename, wav_filename):
            try:
                print(f"Reading WAV file: {wav_filename}")
                sample_rate, audio_data = wavfile.read(wav_filename)
                audio_data = audio_data.reshape(1, -1).astype(np.int16)
                print(f"WAV file read successfully, shape: {audio_data.shape}, sample rate: {sample_rate}")
                return (sample_rate, audio_data)
            except Exception as e:
                print(f"Error reading WAV file: {str(e)}")

        placeholder = _synthetic_placeholder(text)
        if placeholder is not None:
            return placeholder

        print("All methods failed. Falling back to synthetic audio tone")
        return _fallback_tone(max(1, len(text) * 0.1))

    except Exception as e:
        # gTTS needs network access; any failure here still yields audible output.
        print(f"Unexpected error in text-to-speech: {str(e)}")
        return _fallback_tone(1)

    finally:
        for filename in (mp3_filename, wav_filename):
            _remove_temp_file(filename)

# Initialize models
# Both models are loaded eagerly at import time. `response`, `generate_response`
# and the Gradio demo below read these module-level globals directly.
print("Loading ASR model...")
asr_pipeline = load_asr_model()

print("Loading LLM model...")
llm_model, llm_tokenizer = load_llm_model()

# Chat history management
# Conversation log of {"role": ..., "content": ...} dicts that generate_response
# appends to in place. NOTE(review): this is a single global, so every connected
# client shares one conversation — confirm this is intended.
chat_history = []

def generate_response(prompt, max_history_messages=10):
    """Generate the assistant's reply to *prompt* and record the exchange.

    Renders the module-level ``chat_history`` plus the new user turn as a
    plain ``System:/User:/Assistant:`` transcript. It then samples a
    continuation from ``llm_model`` and keeps only the newly generated text,
    cut off where the model starts inventing the next speaker turn.

    Args:
        prompt: The user's (transcribed) message.
        max_history_messages: Upper bound on ``len(chat_history)`` after the
            exchange is recorded. The system message is always kept; the
            oldest user/assistant exchanges are dropped whole when exceeded.

    Returns:
        str: The assistant's reply. It may be empty if the model emitted a
        turn marker straight away.
    """
    if not chat_history:
        chat_history.append({"role": "system", "content": "You are a helpful, friendly AI assistant. Keep your responses concise and conversational."})

    # Render the transcript including the new user turn. The turn is only
    # committed to chat_history after generation succeeds, so a failure
    # leaves no orphaned user message behind.
    speaker_labels = {"system": "System", "user": "User", "assistant": "Assistant"}
    turns = chat_history + [{"role": "user", "content": prompt}]
    full_prompt = "".join(
        f"{speaker_labels[message['role']]}: {message['content']}\n"
        for message in turns
        if message["role"] in speaker_labels
    )
    full_prompt += "Assistant: "

    # Single unpadded sequence; use the tokenizer's own attention mask.
    encoded_input = llm_tokenizer(
        full_prompt,
        return_tensors="pt",
        padding=False,
        add_special_tokens=True,
        return_attention_mask=True,
    )
    input_ids = encoded_input["input_ids"].to(device)
    attention_mask = encoded_input["attention_mask"].to(device)
    print(f"Input shape: {input_ids.shape}, Attention mask shape: {attention_mask.shape}")

    with torch.no_grad():
        try:
            output = llm_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=llm_tokenizer.pad_token_id,
                eos_token_id=llm_tokenizer.eos_token_id,
                use_cache=True,
                no_repeat_ngram_size=3,
                forced_bos_token_id=None,
                forced_eos_token_id=None,
                num_beams=1,  # single beam + do_sample => plain nucleus sampling (not greedy)
            )
        except Exception as e:
            print(f"Error during generation: {e}")
            # Retry with a minimal parameter set.
            output = llm_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.7,
            )

    # For a decoder-only model, generate() returns prompt + continuation;
    # decode only the continuation.
    prompt_length = input_ids.shape[-1]
    response_text = llm_tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    # A base LM tends to keep writing the dialogue; stop at the first
    # speaker turn it invents.
    for marker in ("\nUser:", "\nSystem:", "\nAssistant:"):
        response_text = response_text.split(marker, 1)[0]
    response_text = response_text.strip()

    chat_history.append({"role": "user", "content": prompt})
    chat_history.append({"role": "assistant", "content": response_text})

    # Bound the history (and therefore the prompt length) by dropping whole
    # user/assistant exchanges right after the system message. The previous
    # single pop per turn let the history grow by one message every turn.
    while len(chat_history) > max_history_messages and len(chat_history) > 1:
        del chat_history[1:3]

    return response_text

def response(audio: tuple[int, np.ndarray]):
    """Handle one user utterance for the FastRTC ``ReplyOnPause`` handler.

    Transcribes the captured audio with Whisper, generates a reply with the
    LLM, synthesises it with gTTS and yields it back in ~200 ms chunks.

    Args:
        audio: ``(sample_rate, samples)``. Signed-integer PCM is scaled to
            [-1.0, 1.0) by its dtype's full-scale value (32768 for int16).
            Float samples are assumed to be normalised already (TODO confirm
            which dtype FastRTC delivers).

    Yields:
        ``(sample_rate, chunk)`` pairs, where ``chunk`` has shape ``(1, n)``.
        Nothing is yielded if the transcript is blank.
    """
    sample_rate, audio_data = audio

    samples = np.asarray(audio_data).flatten()
    if np.issubdtype(samples.dtype, np.signedinteger):
        audio_float32 = samples.astype(np.float32) / float(-np.iinfo(samples.dtype).min)
    else:
        # Dividing float input by 32768 would make it near-silent.
        audio_float32 = samples.astype(np.float32)

    transcript = asr_pipeline({
        "sampling_rate": sample_rate,
        "raw": audio_float32
    })

    prompt = transcript["text"]
    print(f"Transcribed: {prompt}")
    if not prompt.strip():
        # Nothing intelligible was said (e.g. background noise); don't answer silence.
        return

    response_text = generate_response(prompt)
    print(f"Response: {response_text}")

    sample_rate, audio_array = gtts_text_to_speech(response_text)

    # Stream the (1, n) int16 reply in 200 ms slices. range() never produces
    # an empty slice; max(1, ...) guards against a zero step.
    chunk_size = max(1, int(sample_rate * 0.2))
    for start in range(0, audio_array.shape[1], chunk_size):
        yield (sample_rate, audio_array[:, start:start + chunk_size])

# FastRTC entry point: audio in, audio out, with `response` wrapped in
# ReplyOnPause (presumably invoked once the speaker pauses — see FastRTC docs).
stream = Stream(
    handler=ReplyOnPause(response),
    modality="audio",
    mode="send-receive",
)

# For testing without WebRTC
def demo():
    """Launch a plain Gradio UI (microphone in, speech out) without WebRTC.

    Each recorded clip is transcribed, answered by the LLM and synthesised
    with gTTS. It uses the same module-level models and chat history as the
    FastRTC handler.
    """
    # Named `app` so it no longer shadows this function's own name.
    with gr.Blocks() as app:
        gr.Markdown("# Local Voice Chatbot")
        audio_input = gr.Audio(sources=["microphone"], type="numpy")
        audio_output = gr.Audio()

        def process_audio(audio):
            """Transcribe one recording, answer it and return ``(rate, samples)`` or None."""
            if audio is None:
                return None

            sample_rate, audio_array = audio
            samples = np.asarray(audio_array)

            # Scale signed-integer PCM by its full-scale value; leave floats as-is.
            if np.issubdtype(samples.dtype, np.signedinteger):
                samples = samples.astype(np.float32) / float(-np.iinfo(samples.dtype).min)
            else:
                samples = samples.astype(np.float32)
            # Gradio numpy audio is (n_samples,) or (n_samples, n_channels).
            # Average the channels instead of interleaving them via flatten().
            if samples.ndim == 2:
                samples = samples.mean(axis=1)
            audio_float32 = samples.flatten()

            transcript = asr_pipeline({
                "sampling_rate": sample_rate,
                "raw": audio_float32
            })

            prompt = transcript["text"]
            print(f"Transcribed: {prompt}")
            if not prompt.strip():
                return None

            response_text = generate_response(prompt)
            print(f"Response: {response_text}")

            sample_rate, audio_array = gtts_text_to_speech(response_text)
            # gtts_text_to_speech returns shape (1, n); Gradio wants a 1-D array.
            return (sample_rate, audio_array[0])

        audio_input.change(process_audio, inputs=[audio_input], outputs=[audio_output])

    app.launch()

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--demo", action="store_true", help="Run Gradio demo instead of WebRTC")
    parser.add_argument(
        "--webrtc",
        action="store_true",
        help="Launch the FastRTC WebRTC UI (stream.ui) instead of the Gradio demo",
    )
    args = parser.parse_args()

    # The plain Gradio demo stays the default: the original author noted
    # "hugging face issues" with the WebRTC path, and a hosted run with no
    # arguments must keep working. `--demo` is therefore accepted but redundant,
    # and WebRTC is strictly opt-in.
    if args.webrtc and not args.demo:
        stream.ui.launch()
    else:
        demo()