import os
import io
import base64
import tempfile
import traceback

import torch
import librosa
import numpy as np
import soundfile as sf
from pydub import AudioSegment
from transformers import AutoModelForImageTextToText, AutoProcessor
from huggingface_hub import login


class Gemma3nInference:
    def __init__(self, device='cuda:0'):
        self.device = device
        
        # Login to Hugging Face using token from environment
        hf_token = os.getenv('HF_TOKEN')
        if hf_token:
            login(token=hf_token)
        else:
            print("Warning: HF_TOKEN not found in environment variables")
        
        print("Loading Gemma 3n model...")
        # Gemma 3n E2B (2B effective parameters); assigned before the try block
        # so the fallback loading path below can also reference it
        model_name = "google/gemma-3n-E2B-it"
        try:
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_name,
                torch_dtype="auto",  # Let it auto-detect the best dtype
                device_map="auto",
                trust_remote_code=True
            )
            self.processor = AutoProcessor.from_pretrained(model_name)
            print(f"Gemma 3n E2B model loaded successfully on device: {self.model.device}")
            print(f"Model dtype: {self.model.dtype}")
        except Exception as e:
            print(f"Error loading Gemma 3n model: {e}")
            print("Trying alternative loading method...")
            try:
                # Fallback: retry with ignore_mismatched_sizes (tolerates partial
                # weight/shape mismatches in the checkpoint) and move the model
                # to the requested device manually instead of using device_map
                self.model = AutoModelForImageTextToText.from_pretrained(
                    model_name,
                    torch_dtype="auto",
                    trust_remote_code=True,
                    ignore_mismatched_sizes=True
                ).to(self.device)
                self.processor = AutoProcessor.from_pretrained(
                    model_name,
                    trust_remote_code=True
                )
                print("Gemma 3n E2B model loaded with alternative method")
            except Exception as e2:
                print(f"Alternative loading also failed: {e2}")
                raise e2

    def preprocess_audio(self, audio_path):
        """Convert audio to Gemma 3n format: 16kHz mono float32 in range [-1, 1]"""
        try:
            # Load audio file
            audio, sr = librosa.load(audio_path, sr=16000, mono=True)
            
            # Peak-normalize only if the signal exceeds [-1, 1]
            peak = float(np.max(np.abs(audio))) if audio.size else 0.0
            if peak > 1.0:
                audio = audio / peak
            
            # Limit to 30 seconds as recommended
            max_samples = 30 * 16000
            if len(audio) > max_samples:
                audio = audio[:max_samples]
            
            return audio.astype(np.float32)
        except Exception as e:
            print(f"Error preprocessing audio: {e}")
            raise

    def create_multimodal_input(self, audio_path, text_prompt="Respond naturally to this audio input"):
        """Create multimodal input for Gemma 3n using the same format as the notebook"""
        try:
            # Preprocess audio (16 kHz mono float32 in [-1, 1], capped at 30 s) and
            # write it to a temporary WAV so the preprocessing actually reaches the
            # processor, which loads the audio from the path in the message
            audio_array = self.preprocess_audio(audio_path)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                sf.write(tmp.name, audio_array, 16000)
                processed_path = tmp.name

            try:
                # Multimodal message format matching the reference notebook
                message = {
                    "role": "user",
                    "content": [
                        {"type": "audio", "audio": processed_path},
                        {"type": "text", "text": text_prompt}
                    ]
                }

                # Tokenize with the Gemma 3n chat template (history is a list of messages)
                inputs = self.processor.apply_chat_template(
                    [message],
                    add_generation_prompt=True,
                    tokenize=True,
                    return_dict=True,
                    return_tensors="pt",
                )
            finally:
                os.unlink(processed_path)

            # BatchFeature.to() casts only floating-point tensors, so token ids stay integral
            return inputs.to(self.device, dtype=self.model.dtype)
        except Exception as e:
            print(f"Error creating multimodal input: {e}")
            traceback.print_exc()
            raise

    def generate_response(self, audio_path, max_new_tokens=256):
        """Generate text response from audio input using notebook approach"""
        try:
            # Create multimodal input
            inputs = self.create_multimodal_input(audio_path)
            input_len = inputs["input_ids"].shape[-1]
            
            # Generate response exactly like the notebook
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    disable_compile=True
                )
            
            # Decode response exactly like the notebook
            text = self.processor.batch_decode(
                outputs[:, input_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            
            return text[0].strip() if text else "No response generated"
        except Exception as e:
            print(f"Error generating response: {e}")
            traceback.print_exc()
            return f"Error: {str(e)}"

    def stream_response(self, audio_path, max_new_tokens=512, temperature=0.9):
        """Generate a sampled text response from audio input.

        Note: despite the name, this decodes the completed output in one pass;
        see the incremental-streaming sketch below for chunked output.
        """
        try:
            # Create multimodal input
            inputs = self.create_multimodal_input(audio_path)

            # Generate a sampled response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                    return_dict_in_generate=True
                )

            # Decode only the newly generated tokens, skipping the prompt
            response = self.processor.tokenizer.decode(
                outputs.sequences[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )
            return response.strip()

        except Exception as e:
            print(f"Error in streaming response: {e}")
            traceback.print_exc()
            return f"Error: {str(e)}"

    def text_to_speech_simple(self, text):
        """Convert text to speech using gTTS"""
        try:
            from gtts import gTTS
            
            # Create TTS object
            tts = gTTS(text=text, lang='en', slow=False)
            
            # Reserve a temp path, closing the handle first so gTTS and pydub can
            # reopen it on platforms that forbid opening an already-open file
            with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
                tmp_path = tmp_file.name

            try:
                tts.save(tmp_path)

                # Convert MP3 to the WAV format the system expects
                audio_segment = AudioSegment.from_mp3(tmp_path)

                # Convert to the expected format (24 kHz, mono, 16-bit)
                audio_segment = audio_segment.set_frame_rate(24000)
                audio_segment = audio_segment.set_channels(1)
                audio_segment = audio_segment.set_sample_width(2)

                # Export to WAV bytes
                audio_buffer = io.BytesIO()
                audio_segment.export(audio_buffer, format="wav")
                return audio_buffer.getvalue()
            finally:
                # Clean up temp file
                os.unlink(tmp_path)
            
        except ImportError:
            print("gTTS not available, falling back to silence")
            # Fallback to silence if gTTS not installed
            duration_seconds = max(1, len(text) / 20)
            sample_rate = 24000
            samples = int(duration_seconds * sample_rate)
            audio_data = np.zeros(samples, dtype=np.int16)
            audio_segment = AudioSegment(
                audio_data.tobytes(),
                frame_rate=sample_rate,
                sample_width=2,
                channels=1
            )
            audio_buffer = io.BytesIO()
            audio_segment.export(audio_buffer, format="wav")
            return audio_buffer.getvalue()
            
        except Exception as e:
            print(f"Error in TTS: {e}")
            # Return minimal audio data on error
            return b'\x00' * 1024

    def process_audio_stream(self, audio_bytes):
        """Process audio stream and return response audio stream"""
        try:
            # Decode base64 audio
            audio_data = base64.b64decode(audio_bytes)
            
            # Save to temporary file
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                f.write(audio_data)
                temp_audio_path = f.name
            
            try:
                # Generate text response
                text_response = self.generate_response(temp_audio_path)
                print(f"Generated response: {text_response}")
                
                # Convert the text response to speech with gTTS
                audio_response = self.text_to_speech_simple(text_response)
                
                return audio_response
                
            finally:
                # Clean up temp file
                if os.path.exists(temp_audio_path):
                    os.unlink(temp_audio_path)
                    
        except Exception as e:
            print(f"Error processing audio stream: {e}")
            traceback.print_exc()
            # Return minimal audio data on error
            return b'\x00' * 1024

    def warm_up(self):
        """Warm up the model"""
        try:
            print("Warming up Gemma 3n model...")
            # Create a short dummy audio
            dummy_audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence
            
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                # Save dummy audio (soundfile is imported at module level)
                sf.write(f.name, dummy_audio, 16000)
                temp_path = f.name

            try:
                # Generate a quick response
                response = self.generate_response(temp_path, max_new_tokens=10)
                print(f"Warm-up response: {response}")
            finally:
                # Clean up
                os.unlink(temp_path)
            
            print("Gemma 3n warm-up complete")
        except Exception as e:
            print(f"Error during warm-up: {e}")
            # Don't fail startup on warm-up errors
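

# A minimal usage sketch, not part of the original module. "sample.wav" is a
# hypothetical input path, and HF_TOKEN is assumed to be set in the environment
# for gated-model access; falls back to CPU when CUDA is unavailable.
if __name__ == "__main__":
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    engine = Gemma3nInference(device=device)
    engine.warm_up()
    print(engine.generate_response("sample.wav"))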