import os
import tempfile
import io
import base64
import traceback

import torch
import librosa
import numpy as np
from transformers import AutoModelForImageTextToText, AutoProcessor
from huggingface_hub import login
from pydub import AudioSegment


class Gemma3nInference:
    def __init__(self, device='cuda:0'):
        self.device = device

        # Login to Hugging Face using token from environment
        hf_token = os.getenv('HF_TOKEN')
        if hf_token:
            login(token=hf_token)
        else:
            print("Warning: HF_TOKEN not found in environment variables")

        print("Loading Gemma 3n model...")
        try:
            # Try loading Gemma 3n E2B (2B effective params) using the correct class
            model_name = "google/gemma-3n-E2B-it"
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_name,
                torch_dtype="auto",  # Let it auto-detect the best dtype
                device_map="auto",
                trust_remote_code=True
            )
            self.processor = AutoProcessor.from_pretrained(model_name)
            print(f"Gemma 3n E2B model loaded successfully on device: {self.model.device}")
            print(f"Model dtype: {self.model.dtype}")
        except Exception as e:
            print(f"Error loading Gemma 3n model: {e}")
            print("Trying alternative loading method...")
            try:
                # Try loading without vision components initially
                from transformers import AutoConfig
                config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

                # Flag a potentially problematic vision tower before loading
                if hasattr(config, 'vision_config'):
                    print("Attempting to load without problematic vision config...")

                self.model = AutoModelForImageTextToText.from_pretrained(
                    model_name,
                    torch_dtype="auto",
                    trust_remote_code=True,
                    ignore_mismatched_sizes=True
                ).to(self.device)
                self.processor = AutoProcessor.from_pretrained(
                    model_name,
                    trust_remote_code=True
                )
                print("Gemma 3n E2B model loaded with alternative method")
            except Exception as e2:
                print(f"Alternative loading also failed: {e2}")
                raise e2

    def preprocess_audio(self, audio_path):
        """Convert audio to Gemma 3n format: 16 kHz mono float32 in range [-1, 1]."""
        try:
            # Load audio file at 16 kHz mono
            audio, _ = librosa.load(audio_path, sr=16000, mono=True)

            # Normalize into [-1, 1] if the signal exceeds that range
            if audio.max() > 1.0 or audio.min() < -1.0:
                audio = audio / max(abs(audio.max()), abs(audio.min()))

            # Limit to 30 seconds as recommended
            max_samples = 30 * 16000
            if len(audio) > max_samples:
                audio = audio[:max_samples]

            return audio.astype(np.float32)
        except Exception as e:
            print(f"Error preprocessing audio: {e}")
            raise

    def create_multimodal_input(self, audio_path, text_prompt="Respond naturally to this audio input"):
        """Create multimodal input for Gemma 3n using the same format as the notebook."""
        try:
            # Preprocess the audio first; this validates the file even though the
            # processor loads it again from the path below
            self.preprocess_audio(audio_path)

            # Multimodal message format, exactly like the notebook
            message = {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": audio_path},  # Pass the path, not the array
                    {"type": "text", "text": text_prompt}
                ]
            }

            # Process with the Gemma 3n processor using the notebook approach
            inputs = self.processor.apply_chat_template(
                [message],  # History is a list of messages
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
            return inputs.to(self.device, dtype=self.model.dtype)
        except Exception as e:
            print(f"Error creating multimodal input: {e}")
            traceback.print_exc()
            raise

    def generate_response(self, audio_path, max_new_tokens=256):
        """Generate a text response from audio input using the notebook approach."""
        try:
            # Create multimodal input
            inputs = self.create_multimodal_input(audio_path)
            input_len = inputs["input_ids"].shape[-1]

            # Generate response exactly like the notebook
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    disable_compile=True
                )

            # Decode only the newly generated tokens, as in the notebook
            text = self.processor.batch_decode(
                outputs[:, input_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            return text[0].strip() if text else "No response generated"
        except Exception as e:
            print(f"Error generating response: {e}")
            traceback.print_exc()
            return f"Error: {str(e)}"

    def stream_response(self, audio_path, max_new_tokens=512, temperature=0.9):
        """Generate a sampled text response from audio input.

        Note: despite the name, this currently generates the full response in one
        pass rather than streaming tokens incrementally.
        """
        try:
            # Create multimodal input
            inputs = self.create_multimodal_input(audio_path)

            # Sampled generation
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                    output_scores=True
                )

            # Decode only the tokens generated after the prompt
            response = self.processor.tokenizer.decode(
                outputs.sequences[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )
            return response.strip()
        except Exception as e:
            print(f"Error in streaming response: {e}")
            traceback.print_exc()
            return f"Error: {str(e)}"

    def text_to_speech_simple(self, text):
        """Convert text to speech using gTTS."""
        try:
            from gtts import gTTS

            # Create TTS object
            tts = gTTS(text=text, lang='en', slow=False)

            # Save to temporary file
            with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
                tts.save(tmp_file.name)

            # Convert MP3 to the WAV format that the system expects
            audio_segment = AudioSegment.from_mp3(tmp_file.name)

            # Convert to expected format (24 kHz, mono, 16-bit)
            audio_segment = audio_segment.set_frame_rate(24000)
            audio_segment = audio_segment.set_channels(1)
            audio_segment = audio_segment.set_sample_width(2)

            # Export to WAV bytes
            audio_buffer = io.BytesIO()
            audio_segment.export(audio_buffer, format="wav")

            # Clean up temp file
            os.unlink(tmp_file.name)

            return audio_buffer.getvalue()
        except ImportError:
            print("gTTS not available, falling back to silence")
            # Fallback: return silence roughly proportional to the text length
            duration_seconds = max(1, len(text) / 20)
            sample_rate = 24000
            samples = int(duration_seconds * sample_rate)
            audio_data = np.zeros(samples, dtype=np.int16)

            audio_segment = AudioSegment(
                audio_data.tobytes(),
                frame_rate=sample_rate,
                sample_width=2,
                channels=1
            )
            audio_buffer = io.BytesIO()
            audio_segment.export(audio_buffer, format="wav")
            return audio_buffer.getvalue()
        except Exception as e:
            print(f"Error in TTS: {e}")
            # Return minimal audio data on error
            return b'\x00' * 1024

    def process_audio_stream(self, audio_bytes):
        """Process a base64-encoded audio payload and return response audio bytes."""
        try:
            # Decode base64 audio
            audio_data = base64.b64decode(audio_bytes)

            # Save to temporary file
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                f.write(audio_data)
                temp_audio_path = f.name

            try:
                # Generate text response
                text_response = self.generate_response(temp_audio_path)
                print(f"Generated response: {text_response}")

                # Convert to speech (simple gTTS-based placeholder)
                audio_response = self.text_to_speech_simple(text_response)
                return audio_response
            finally:
                # Clean up temp file
                if os.path.exists(temp_audio_path):
                    os.unlink(temp_audio_path)
        except Exception as e:
            print(f"Error processing audio stream: {e}")
            traceback.print_exc()
            # Return minimal audio data on error
            return b'\x00' * 1024

    def warm_up(self):
        """Warm up the model with a short dummy generation."""
        try:
            print("Warming up Gemma 3n model...")

            # Create a short dummy audio clip: 1 second of silence
            dummy_audio = np.zeros(16000, dtype=np.float32)

            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                # Save dummy audio
                import soundfile as sf
                sf.write(f.name, dummy_audio, 16000)

                # Generate a quick response
                response = self.generate_response(f.name, max_new_tokens=10)
                print(f"Warm-up response: {response}")

                # Clean up
                os.unlink(f.name)

            print("Gemma 3n warm-up complete")
        except Exception as e:
            print(f"Error during warm-up: {e}")
            # Don't fail startup on warm-up errors