# Test_Voice/voice_chat/gemma3n_inference.py
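"""Gemma 3n audio-to-text inference wrapper.

Loads google/gemma-3n-E2B-it with AutoModelForImageTextToText, generates a text
reply from a recorded audio clip, and synthesizes a spoken response with gTTS
(falling back to silence when gTTS is unavailable).
"""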
import os
import tempfile
import torch
import librosa
import numpy as np
from transformers import AutoModelForImageTextToText, AutoProcessor
from huggingface_hub import login
import io
from pydub import AudioSegment
import base64
import traceback

class Gemma3nInference:
    def __init__(self, device='cuda:0'):
        self.device = device

        # Login to Hugging Face using token from environment
        hf_token = os.getenv('HF_TOKEN')
        if hf_token:
            login(token=hf_token)
        else:
            print("Warning: HF_TOKEN not found in environment variables")

        print("Loading Gemma 3n model...")
        try:
            # Try loading Gemma 3n E2B (2B effective params) using the correct class
            model_name = "google/gemma-3n-E2B-it"
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_name,
                torch_dtype="auto",  # Let it auto-detect the best dtype
                device_map="auto",
                trust_remote_code=True
            )
            self.processor = AutoProcessor.from_pretrained(model_name)
            print(f"Gemma 3n E2B model loaded successfully on device: {self.model.device}")
            print(f"Model dtype: {self.model.dtype}")
        except Exception as e:
            print(f"Error loading Gemma 3n model: {e}")
            print("Trying alternative loading method...")
            try:
                # Try loading without vision components initially
                from transformers import AutoConfig
                config = AutoConfig.from_pretrained(
                    model_name,
                    trust_remote_code=True
                )
                # Disable vision tower if causing issues
                if hasattr(config, 'vision_config'):
                    print("Attempting to load without problematic vision config...")
                self.model = AutoModelForImageTextToText.from_pretrained(
                    model_name,
                    torch_dtype="auto",
                    trust_remote_code=True,
                    ignore_mismatched_sizes=True
                ).to(self.device)
                self.processor = AutoProcessor.from_pretrained(
                    model_name,
                    trust_remote_code=True
                )
                print("Gemma 3n E2B model loaded with alternative method")
            except Exception as e2:
                print(f"Alternative loading also failed: {e2}")
                raise e2

    def preprocess_audio(self, audio_path):
        """Convert audio to Gemma 3n format: 16kHz mono float32 in range [-1, 1]"""
        try:
            # Load audio file
            audio, sr = librosa.load(audio_path, sr=16000, mono=True)

            # Ensure audio is in range [-1, 1]
            if audio.max() > 1.0 or audio.min() < -1.0:
                audio = audio / max(abs(audio.max()), abs(audio.min()))

            # Limit to 30 seconds as recommended
            max_samples = 30 * 16000
            if len(audio) > max_samples:
                audio = audio[:max_samples]

            return audio.astype(np.float32)
        except Exception as e:
            print(f"Error preprocessing audio: {e}")
            raise

    def create_multimodal_input(self, audio_path, text_prompt="Respond naturally to this audio input"):
        """Create multimodal input for Gemma 3n using the same format as the notebook"""
        try:
            # Preprocess the audio up front to validate it; the processor itself
            # loads the clip from the file path passed in the message below.
            audio_array = self.preprocess_audio(audio_path)

            # Create multimodal message format exactly like the notebook
            message = {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": audio_path},  # Use path instead of array
                    {"type": "text", "text": text_prompt}
                ]
            }

            # Process with Gemma 3n processor using the notebook approach
            inputs = self.processor.apply_chat_template(
                [message],  # History is a list
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
            return inputs.to(self.device, dtype=self.model.dtype)
        except Exception as e:
            print(f"Error creating multimodal input: {e}")
            traceback.print_exc()
            raise

    def generate_response(self, audio_path, max_new_tokens=256):
        """Generate text response from audio input using notebook approach"""
        try:
            # Create multimodal input
            inputs = self.create_multimodal_input(audio_path)
            input_len = inputs["input_ids"].shape[-1]

            # Generate response exactly like the notebook
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    disable_compile=True
                )

            # Decode response exactly like the notebook
            text = self.processor.batch_decode(
                outputs[:, input_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            return text[0].strip() if text else "No response generated"
        except Exception as e:
            print(f"Error generating response: {e}")
            traceback.print_exc()
            return f"Error: {str(e)}"

    def stream_response(self, audio_path, max_new_tokens=512, temperature=0.9):
        """Generate a sampled text response from audio input.

        Note: despite the name, this returns the full decoded response rather
        than streaming tokens incrementally.
        """
        try:
            # Create multimodal input
            inputs = self.create_multimodal_input(audio_path)

            # Generate sampled response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                    output_scores=True
                )

            # Decode only the newly generated tokens
            response = self.processor.tokenizer.decode(
                outputs.sequences[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )
            return response.strip()
        except Exception as e:
            print(f"Error in streaming response: {e}")
            traceback.print_exc()
            return f"Error: {str(e)}"

    def text_to_speech_simple(self, text):
        """Convert text to speech using gTTS"""
        try:
            from gtts import gTTS

            # Create TTS object
            tts = gTTS(text=text, lang='en', slow=False)

            # Save to temporary file
            with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
                tts.save(tmp_file.name)

                # Convert MP3 to WAV format that the system expects
                audio_segment = AudioSegment.from_mp3(tmp_file.name)

                # Convert to expected format (24kHz, mono, 16-bit)
                audio_segment = audio_segment.set_frame_rate(24000)
                audio_segment = audio_segment.set_channels(1)
                audio_segment = audio_segment.set_sample_width(2)

                # Export to WAV bytes
                audio_buffer = io.BytesIO()
                audio_segment.export(audio_buffer, format="wav")

                # Clean up temp file
                os.unlink(tmp_file.name)

                return audio_buffer.getvalue()
        except ImportError:
            print("gTTS not available, falling back to silence")
            # Fallback to silence if gTTS not installed
            duration_seconds = max(1, len(text) / 20)
            sample_rate = 24000
            samples = int(duration_seconds * sample_rate)
            audio_data = np.zeros(samples, dtype=np.int16)
            audio_segment = AudioSegment(
                audio_data.tobytes(),
                frame_rate=sample_rate,
                sample_width=2,
                channels=1
            )
            audio_buffer = io.BytesIO()
            audio_segment.export(audio_buffer, format="wav")
            return audio_buffer.getvalue()
        except Exception as e:
            print(f"Error in TTS: {e}")
            # Return minimal audio data on error
            return b'\x00' * 1024

    def process_audio_stream(self, audio_bytes):
        """Process audio stream and return response audio stream"""
        try:
            # Decode base64 audio
            audio_data = base64.b64decode(audio_bytes)

            # Save to temporary file
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                f.write(audio_data)
                temp_audio_path = f.name

            try:
                # Generate text response
                text_response = self.generate_response(temp_audio_path)
                print(f"Generated response: {text_response}")

                # Convert the text response to speech with gTTS
                audio_response = self.text_to_speech_simple(text_response)
                return audio_response
            finally:
                # Clean up temp file
                if os.path.exists(temp_audio_path):
                    os.unlink(temp_audio_path)
        except Exception as e:
            print(f"Error processing audio stream: {e}")
            traceback.print_exc()
            # Return minimal audio data on error
            return b'\x00' * 1024

    def warm_up(self):
        """Warm up the model"""
        try:
            print("Warming up Gemma 3n model...")

            # Create a short dummy audio
            dummy_audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence

            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                # Save dummy audio
                import soundfile as sf
                sf.write(f.name, dummy_audio, 16000)

                # Generate a quick response
                response = self.generate_response(f.name, max_new_tokens=10)
                print(f"Warm-up response: {response}")

                # Clean up
                os.unlink(f.name)

            print("Gemma 3n warm-up complete")
        except Exception as e:
            print(f"Error during warm-up: {e}")
            # Don't fail startup on warm-up errors
            pass
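

if __name__ == "__main__":
    # Minimal usage sketch, assuming a CUDA device, an HF_TOKEN with access to
    # google/gemma-3n-E2B-it, and a local recording at the hypothetical path
    # "sample.wav"; file names here are placeholders for illustration only.
    engine = Gemma3nInference(device="cuda:0")
    engine.warm_up()

    # Audio file -> text reply
    reply = engine.generate_response("sample.wav", max_new_tokens=128)
    print(f"Model reply: {reply}")

    # Text reply -> spoken WAV bytes (gTTS or silent fallback)
    wav_bytes = engine.text_to_speech_simple(reply)
    with open("reply.wav", "wb") as out:
        out.write(wav_bytes)

    # Base64-encoded WAV in -> response WAV bytes out, end to end
    with open("sample.wav", "rb") as src:
        b64_audio = base64.b64encode(src.read())
    response_wav = engine.process_audio_stream(b64_audio)
    with open("stream_reply.wav", "wb") as out:
        out.write(response_wav)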