File size: 14,714 Bytes
e724e7e
 
 
a958ea7
e724e7e
c07dd66
e724e7e
7afe31a
 
289ad8b
7afe31a
 
 
 
 
289ad8b
7afe31a
289ad8b
 
7afe31a
 
 
 
 
 
 
289ad8b
7afe31a
 
 
 
 
 
e724e7e
7afe31a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e724e7e
7afe31a
 
 
289ad8b
7afe31a
5c42f52
7afe31a
7dc0ac9
5c42f52
dbf60e3
de7876c
5c42f52
dbf60e3
 
 
 
 
 
5c42f52
 
 
 
fe65571
dbf60e3
5c42f52
dbf60e3
 
5c42f52
 
 
 
 
 
 
 
 
 
7dc0ac9
dbf60e3
7afe31a
 
 
 
289ad8b
 
de7876c
dbf60e3
 
 
 
289ad8b
 
dbf60e3
 
 
 
 
de7876c
 
 
dbf60e3
 
 
 
 
 
de7876c
 
dbf60e3
 
 
de7876c
dbf60e3
 
 
de7876c
 
 
dbf60e3
 
 
 
 
 
 
 
de7876c
dbf60e3
 
de7876c
dbf60e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de7876c
dbf60e3
de7876c
 
dbf60e3
de7876c
dbf60e3
de7876c
dbf60e3
de7876c
 
dbf60e3
de7876c
dbf60e3
de7876c
 
 
 
 
 
dbf60e3
 
 
 
 
 
 
 
 
de7876c
 
 
 
 
 
dbf60e3
 
519f37a
289ad8b
7afe31a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c42f52
 
 
 
 
 
7dc0ac9
 
5c42f52
 
bd4a44f
5c42f52
7afe31a
 
7dc0ac9
de7876c
7afe31a
 
 
dbf60e3
5c42f52
 
 
 
7afe31a
 
 
 
 
5c42f52
7afe31a
 
 
 
 
 
 
 
 
 
289ad8b
 
 
 
 
 
 
 
 
 
 
 
7afe31a
289ad8b
7afe31a
 
 
289ad8b
7afe31a
289ad8b
 
7afe31a
289ad8b
 
7afe31a
 
 
 
e724e7e
 
 
 
 
7afe31a
 
 
 
 
 
 
 
 
 
 
 
 
 
289ad8b
 
 
 
 
 
 
 
 
7afe31a
 
 
 
 
 
289ad8b
7afe31a
 
 
 
 
 
 
 
 
 
 
289ad8b
8d9d85d
289ad8b
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
from fastrtc import (
    ReplyOnPause, AdditionalOutputs, Stream,
    audio_to_bytes, aggregate_bytes_to_16bit
)
import gradio as gr
import time
import numpy as np
import torch
import os
import tempfile
from transformers import (
    AutoModelForSpeechSeq2Seq, 
    AutoProcessor, 
    pipeline,
    AutoTokenizer, 
    AutoModelForCausalLM
)
from gtts import gTTS
from scipy.io import wavfile

# Pick the compute target once: the GPU in half precision when CUDA is
# available, otherwise the CPU in full float32 precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Step 1: Audio transcription with Whisper
def load_asr_model():
    """Create the speech-to-text pipeline backed by Whisper (small).

    The checkpoint is loaded in the module-level ``torch_dtype``, moved to
    ``device`` and combined with its processor's tokenizer and feature
    extractor into an ``automatic-speech-recognition`` pipeline.

    Returns:
        The ready-to-call transformers pipeline.
    """
    checkpoint = "openai/whisper-small"

    whisper = AutoModelForSpeechSeq2Seq.from_pretrained(
        checkpoint,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    whisper.to(device)

    whisper_processor = AutoProcessor.from_pretrained(checkpoint)

    asr_options = {
        "max_new_tokens": 128,
        "chunk_length_s": 30,
        "batch_size": 16,
        "return_timestamps": False,
        "torch_dtype": torch_dtype,
        "device": device,
    }
    return pipeline(
        "automatic-speech-recognition",
        model=whisper,
        tokenizer=whisper_processor.tokenizer,
        feature_extractor=whisper_processor.feature_extractor,
        **asr_options,
    )

# Step 2: Text generation with a smaller LLM
def load_llm_model():
    """Load the OPT-1.3b causal LM and its tokenizer with a well-defined pad token.

    If the tokenizer has no pad token, or reuses EOS for padding, a dedicated
    ``[PAD]`` special token is added, the embedding matrix is resized to the
    new vocabulary size, and the model config is pointed at the new pad id.
    In every case the config ends up with a non-None ``pad_token_id`` when the
    tokenizer has one.

    Returns:
        ``(model, tokenizer)``, with the model already moved to ``device``.
    """
    checkpoint = "facebook/opt-1.3b"

    tok = AutoTokenizer.from_pretrained(checkpoint)
    print(f"Initial pad token ID: {tok.pad_token_id}, EOS token ID: {tok.eos_token_id}")

    lm = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )

    # Padding with EOS (or with nothing) makes padded positions ambiguous,
    # so give the tokenizer and model a distinct pad token.
    pad_missing_or_shared = tok.pad_token_id is None or tok.pad_token_id == tok.eos_token_id
    if pad_missing_or_shared:
        tok.add_special_tokens({'pad_token': '[PAD]'})
        lm.resize_token_embeddings(len(tok))  # make room for the new token id
        lm.config.pad_token_id = tok.pad_token_id

        # NOTE(review): decoder_start_token_id only matters for encoder-decoder
        # models; kept for parity with the original setup.
        needs_decoder_start = (
            hasattr(lm.config, 'decoder_start_token_id')
            and lm.config.decoder_start_token_id is None
        )
        if needs_decoder_start:
            lm.config.decoder_start_token_id = tok.pad_token_id

        print(f"Modified token IDs - PAD: {tok.pad_token_id}, EOS: {tok.eos_token_id}")
        print(f"Model config - PAD: {lm.config.pad_token_id}, EOS: {lm.config.eos_token_id}")

    # Final guard: never leave the config without a pad id.
    if getattr(lm.config, 'pad_token_id', None) is None:
        lm.config.pad_token_id = tok.pad_token_id

    lm.to(device)
    return lm, tok

# Step 3: Text-to-Speech with gTTS (Google Text-to-Speech)
def _wav_is_usable(path):
    """Return True if *path* exists and is larger than 100 bytes (more than a bare header)."""
    return os.path.exists(path) and os.path.getsize(path) > 100


def _mp3_to_wav_ffmpeg(mp3_path, wav_path, sample_rate):
    """Convert *mp3_path* to mono 16-bit PCM WAV at *sample_rate* with the ffmpeg CLI.

    Returns True when a usable WAV file was written, False when ffmpeg is
    missing, exits non-zero, or leaves an empty/tiny file.
    """
    import subprocess

    cmd = ['ffmpeg', '-y', '-i', mp3_path, '-acodec', 'pcm_s16le',
           '-ar', str(sample_rate), '-ac', '1', wav_path]
    print(f"Running ffmpeg command: {' '.join(cmd)}")
    try:
        # Argument list, no shell: the paths are never interpreted by a shell.
        subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
    except (OSError, subprocess.SubprocessError) as e:
        print(f"ffmpeg conversion failed: {str(e)}")
        return False

    if _wav_is_usable(wav_path):
        print(f"WAV file successfully created with ffmpeg: {wav_path}, size: {os.path.getsize(wav_path)}")
        return True
    print(f"ffmpeg ran but WAV file is missing or too small: {wav_path}")
    return False


def _mp3_to_wav_pydub(mp3_path, wav_path, sample_rate):
    """Convert *mp3_path* to mono WAV at *sample_rate* with pydub; return True on success."""
    try:
        from pydub import AudioSegment
        print("Converting MP3 to WAV using pydub...")
        sound = AudioSegment.from_mp3(mp3_path).set_frame_rate(sample_rate).set_channels(1)
        exported = sound.export(wav_path, format="wav")
        # pydub hands back the output file object; close it so the data is
        # flushed before the size check and the temp file can be deleted later.
        close = getattr(exported, "close", None)
        if callable(close):
            close()
    except Exception as e:  # best-effort fallback: missing pydub, decode errors, ...
        print(f"pydub conversion failed: {str(e)}")
        return False

    if _wav_is_usable(wav_path):
        print(f"WAV file successfully created with pydub: {wav_path}, size: {os.path.getsize(wav_path)}")
        return True
    print("pydub ran but WAV file is missing or too small")
    return False


def _synthetic_speech_like(text, sample_rate):
    """Build an audible placeholder (four summed sines with attack/release ramps).

    The length is about 75 ms per character of *text*. This is not speech;
    it only signals that a reply exists when no MP3->WAV converter works.

    Returns:
        ``(sample_rate, int16 array of shape (1, n))``, or None if *text* is
        too short to yield any samples.
    """
    print("Generating synthetic speech directly...")
    duration = len(text) * 0.075  # Approx timing
    n = int(sample_rate * duration)
    if n == 0:
        return None
    t = np.linspace(0, duration, n, endpoint=False)

    audio = np.zeros_like(t)
    for phase, freq in enumerate((220, 440, 330, 550)):
        audio += 0.2 * np.sin(2 * np.pi * freq * t + phase)

    # Clamp the ramps to the signal length; short texts used to crash here.
    envelope = np.ones_like(t)
    attack = min(int(0.01 * sample_rate), n)
    release = min(int(0.1 * sample_rate), n - attack)
    envelope[:attack] = np.linspace(0, 1, attack)
    if release > 0:  # guard: envelope[-0:] would select the whole array
        envelope[n - release:] = np.linspace(1, 0, release)
    audio *= envelope

    peak = np.max(np.abs(audio))
    if peak > 0:
        audio /= peak
    return (sample_rate, (audio * 32767).astype(np.int16).reshape(1, -1))


def _tone(sample_rate, duration_sec, freq=440.0):
    """Return a full-scale sine beep as ``(sample_rate, int16 array of shape (1, n))``."""
    n = int(sample_rate * duration_sec)
    wave = np.sin(2 * np.pi * np.arange(n) * freq / sample_rate)
    return (sample_rate, (wave * 32767).astype(np.int16).reshape(1, -1))


def gtts_text_to_speech(text):
    """Convert text to speech using gTTS and return it as int16 PCM.

    gTTS writes an MP3, which is converted to a mono 24 kHz WAV (ffmpeg first,
    pydub as fallback) and read back. If no converter works, an audible
    placeholder is returned instead, so callers always receive audio.
    Temporary files are removed before returning.

    Args:
        text: Text to speak. Empty, whitespace-only or None text is replaced
            by a short default sentence.

    Returns:
        ``(sample_rate, audio)`` where ``audio`` is an ``np.int16`` array of
        shape ``(1, n_samples)``.
    """
    import uuid

    sample_rate = 24000
    # One unique stem for both files. Two separate time.time() calls could
    # differ, and pid+timestamp can collide between threads of one process.
    stem = os.path.join(tempfile.gettempdir(), f"tts_temp_{os.getpid()}_{uuid.uuid4().hex}")
    mp3_filename = stem + ".mp3"
    wav_filename = stem + ".wav"

    # NOTE: `np` and `wavfile` are the module-level imports. Re-importing them
    # anywhere inside this function would make them locals for the WHOLE
    # function and raise UnboundLocalError on every path that skips that import.
    try:
        if not text or text.isspace():
            text = "I don't have a response for that."

        gTTS(text=text, lang='en', slow=False).save(mp3_filename)
        print(f"MP3 file created: {mp3_filename}, size: {os.path.getsize(mp3_filename)}")

        converted = (_mp3_to_wav_ffmpeg(mp3_filename, wav_filename, sample_rate)
                     or _mp3_to_wav_pydub(mp3_filename, wav_filename, sample_rate))

        if converted:
            try:
                print(f"Reading WAV file: {wav_filename}")
                wav_rate, audio_data = wavfile.read(wav_filename)
                audio_data = audio_data.reshape(1, -1).astype(np.int16)
                print(f"WAV file read successfully, shape: {audio_data.shape}, sample rate: {wav_rate}")
                return (wav_rate, audio_data)
            except (OSError, ValueError) as e:
                print(f"Error reading WAV file: {str(e)}")
        else:
            placeholder = _synthetic_speech_like(text, sample_rate)
            if placeholder is not None:
                return placeholder

        print("All methods failed. Falling back to synthetic audio tone")
        return _tone(sample_rate, max(1, len(text) * 0.1))

    except Exception as e:
        # Top-level boundary: the voice loop must always get audio back.
        print(f"Unexpected error in text-to-speech: {str(e)}")
        return _tone(sample_rate, 1)

    finally:
        for filename in (mp3_filename, wav_filename):
            try:
                if os.path.exists(filename):
                    os.remove(filename)
            except OSError as e:
                print(f"Failed to remove temporary file {filename}: {str(e)}")

# Initialize models
# Both loaders run at import time, so importing this module loads the Whisper
# and OPT checkpoints (presumably downloading them on first run) before any
# handler below can be used.
print("Loading ASR model...")
asr_pipeline = load_asr_model()

print("Loading LLM model...")
llm_model, llm_tokenizer = load_llm_model()

# Chat history management
# Process-wide conversation log used by generate_response(); each entry is a
# {"role": "system" | "user" | "assistant", "content": str} dict.
# NOTE(review): this single list is shared by every caller/session and is not
# lock-protected — confirm that is intended.
chat_history: list[dict[str, str]] = []

def generate_response(prompt):
    """Generate an assistant reply to *prompt* and record the turn in ``chat_history``.

    On first use the module-level ``chat_history`` is seeded with a system
    message. The history is rendered as a plain ``Role: text`` transcript
    ending in ``"Assistant: "`` and continued by ``llm_model``. Only the newly
    generated tokens are decoded. Anything after the model starts inventing
    the next speaker turn is discarded. The history is then trimmed to at most
    10 entries (the system message plus the 9 most recent messages).

    Args:
        prompt: The user's (transcribed) utterance.

    Returns:
        The assistant's reply text. It may be empty if the model produced
        nothing before the next turn marker.
    """
    if not chat_history:
        chat_history.append({"role": "system", "content": "You are a helpful, friendly AI assistant. Keep your responses concise and conversational."})

    chat_history.append({"role": "user", "content": prompt})

    role_labels = {"system": "System", "user": "User", "assistant": "Assistant"}
    full_prompt = "".join(
        f"{role_labels[message['role']]}: {message['content']}\n"
        for message in chat_history
        if message["role"] in role_labels
    ) + "Assistant: "

    # Single unpadded sequence, so the attention mask is all ones.
    input_ids = llm_tokenizer.encode(full_prompt, return_tensors='pt').to(device)
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        output = llm_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=llm_tokenizer.pad_token_id,
            eos_token_id=llm_tokenizer.eos_token_id,
            use_cache=True,
            no_repeat_ngram_size=3
        )

    # A decoder-only model returns prompt + continuation; decode only the
    # continuation instead of re-splitting the whole transcript.
    new_tokens = output[0][input_ids.shape[-1]:]
    response_text = llm_tokenizer.decode(new_tokens, skip_special_tokens=True)

    # A base LM tends to keep writing the dialogue ("User: ...", "Assistant: ...").
    # Keep only this assistant turn.
    for marker in ("\nUser:", "\nSystem:", "\nAssistant:"):
        response_text = response_text.split(marker, 1)[0]
    response_text = response_text.strip()

    chat_history.append({"role": "assistant", "content": response_text})

    # Each call appends two messages (user + assistant), so a single pop let the
    # history grow forever. Drop the oldest non-system messages until the
    # history has at most 10 entries again.
    while len(chat_history) > 10:
        chat_history.pop(1)

    return response_text

def response(audio: tuple[int, np.ndarray]):
    """Answer one user utterance for the FastRTC ``ReplyOnPause`` handler.

    Runs Whisper ASR, then the LLM reply, then gTTS synthesis, and streams the
    synthesized reply back in 200 ms slices.

    Args:
        audio: ``(sample_rate, samples)``. The samples are scaled by 1/32768,
            i.e. treated as int16 PCM — TODO confirm against FastRTC's contract.

    Yields:
        ``(sample_rate, chunk)`` tuples. Each chunk is a non-empty slice of at
        most 200 ms taken along axis 1 of the TTS output.
    """
    in_rate, samples = audio

    # The ASR pipeline is fed float32 in [-1.0, 1.0).
    normalized = samples.flatten().astype(np.float32) / 32768.0
    asr_result = asr_pipeline({
        "sampling_rate": in_rate,
        "raw": normalized
    })
    heard = asr_result["text"]
    print(f"Transcribed: {heard}")

    reply = generate_response(heard)
    print(f"Response: {reply}")

    out_rate, speech = gtts_text_to_speech(reply)

    step = int(out_rate * 0.2)  # 200ms chunks
    total = speech.shape[1]
    for start in range(0, total, step):
        piece = speech[:, start:start + step]
        if piece.size:
            yield (out_rate, piece)

# FastRTC stream: two-way ("send-receive") audio. ReplyOnPause wraps
# response(), presumably calling it once the speaker pauses (per its name).
# NOTE(review): the __main__ block launches demo(), not this stream, so this
# object is only used if something else (e.g. an importing host) launches it.
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response),
)

# For testing without WebRTC
def demo():
    """Launch a plain Gradio UI: record from the microphone, hear the bot reply.

    Every change of the microphone component runs the same ASR -> LLM -> TTS
    chain as ``response()`` and plays the result in a second Audio component.
    """
    with gr.Blocks() as ui:  # not named `demo`, to avoid shadowing this function
        gr.Markdown("# Local Voice Chatbot")
        audio_input = gr.Audio(sources=["microphone"], type="numpy")
        audio_output = gr.Audio()

        def process_audio(audio):
            """Answer one recording; return ``(sample_rate, 1-D int16 array)`` or None."""
            if audio is None:
                # The component was cleared: nothing to answer.
                return None

            in_rate, samples = audio
            samples = np.asarray(samples)

            # Scale to float32 for Whisper. Integer PCM is divided by its
            # full-scale value (32768 for int16, same as before); float input
            # is assumed to be normalized already.
            if np.issubdtype(samples.dtype, np.integer):
                full_scale = float(np.iinfo(samples.dtype).max) + 1.0
                samples = samples.astype(np.float32) / full_scale
            else:
                samples = samples.astype(np.float32)

            # Gradio's numpy format is (samples,) for mono or (samples, channels)
            # for multi-channel audio (per Gradio docs — verify for your version).
            # flatten() would interleave the channels, so downmix instead.
            if samples.ndim > 1:
                samples = samples.mean(axis=1)
            audio_float32 = samples.astype(np.float32).ravel()

            transcript = asr_pipeline({
                "sampling_rate": in_rate,
                "raw": audio_float32
            })

            prompt = transcript["text"]
            print(f"Transcribed: {prompt}")

            response_text = generate_response(prompt)
            print(f"Response: {response_text}")

            out_rate, reply_audio = gtts_text_to_speech(response_text)
            # TTS output is (1, n); hand the output component the single channel.
            return (out_rate, reply_audio[0])

        audio_input.change(process_audio, inputs=[audio_input], outputs=[audio_output])

    ui.launch()

if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--demo", action="store_true", help="Run Gradio demo instead of WebRTC")
    # Parsed so that --demo and -h are accepted and unknown flags are rejected.
    # The value is currently ignored: because of Hugging Face hosting issues the
    # Gradio demo is always launched. No FastRTC server code exists for the
    # non-demo branch yet.
    cli.parse_args()
    demo()