# Spaces: Paused
# File size: 11,285 Bytes
# ee78b3d (presumably the revision id shown by the file viewer)
import os
import tempfile
import torch
import librosa
import numpy as np
from transformers import AutoModelForImageTextToText, AutoProcessor
from huggingface_hub import login
import io
from pydub import AudioSegment
import base64
import traceback
class Gemma3nInference:
    """Audio-in / audio-out wrapper around the Gemma 3n E2B instruction model.

    The instance loads the model and processor once in ``__init__`` and then
    offers helpers to turn an audio file into a text reply
    (``generate_response`` / ``stream_response``), to turn text into WAV
    bytes (``text_to_speech_simple``) and to chain both for base64-encoded
    audio (``process_audio_stream``).
    """

    MODEL_NAME = "google/gemma-3n-E2B-it"
    # Input audio format used by preprocess_audio and warm_up.
    SAMPLE_RATE = 16000
    MAX_AUDIO_SECONDS = 30
    # Output audio format produced by text_to_speech_simple.
    TTS_SAMPLE_RATE = 24000
    # Placeholder payload returned when no audio could be produced.
    # NOTE(review): this is raw zero bytes without a WAV header, unlike the
    # success path; confirm downstream consumers tolerate it.
    ERROR_AUDIO = b'\x00' * 1024

    def __init__(self, device='cuda:0'):
        """Log in to Hugging Face (if a token is set) and load the model.

        Args:
            device: Device used by the fallback loading path. The primary
                path lets ``device_map="auto"`` choose the placement.

        Raises:
            Exception: Whatever the fallback loader raised, if both loading
                attempts fail.
        """
        self.device = device

        # Login to Hugging Face using token from environment
        hf_token = os.getenv('HF_TOKEN')
        if hf_token:
            login(token=hf_token)
        else:
            print("Warning: HF_TOKEN not found in environment variables")

        print("Loading Gemma 3n model...")
        model_name = self.MODEL_NAME
        try:
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_name,
                torch_dtype="auto",  # Let it auto-detect the best dtype
                device_map="auto",
                trust_remote_code=True,
            )
            self.processor = AutoProcessor.from_pretrained(model_name)
            print(f"Gemma 3n E2B model loaded successfully on device: {self.model.device}")
            print(f"Model dtype: {self.model.dtype}")
        except Exception as e:
            # Deliberately broad: any failure of the primary path triggers
            # the second attempt below.
            print(f"Error loading Gemma 3n model: {e}")
            print("Trying alternative loading method...")
            self._load_model_fallback(model_name)

    def _load_model_fallback(self, model_name):
        """Second loading attempt: no device_map, explicit move to self.device."""
        try:
            from transformers import AutoConfig

            config = AutoConfig.from_pretrained(
                model_name,
                trust_remote_code=True,
            )
            # Only a diagnostic message: nothing below actually removes or
            # disables the vision configuration.
            if hasattr(config, 'vision_config'):
                print("Attempting to load without problematic vision config...")

            self.model = AutoModelForImageTextToText.from_pretrained(
                model_name,
                torch_dtype="auto",
                trust_remote_code=True,
                ignore_mismatched_sizes=True,
            ).to(self.device)
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                trust_remote_code=True,
            )
            print("Gemma 3n E2B model loaded with alternative method")
        except Exception as e2:
            print(f"Alternative loading also failed: {e2}")
            raise

    @staticmethod
    def _normalize_and_trim(audio, max_samples):
        """Scale audio into [-1, 1] if it exceeds that range, then truncate.

        Args:
            audio: 1-D array-like of samples.
            max_samples: Maximum number of samples to keep.

        Returns:
            A float32 numpy array with at most ``max_samples`` samples. An
            empty input yields an empty array instead of raising.
        """
        audio = np.asarray(audio)
        # Guard: max() of a zero-size array raises ValueError.
        if audio.size:
            peak = float(np.abs(audio).max())
            if peak > 1.0:
                audio = audio / peak
        return audio[:max_samples].astype(np.float32)

    @staticmethod
    def _remove_file(path):
        """Best-effort deletion of a temporary file; ``None`` is ignored."""
        if path is None:
            return
        try:
            os.unlink(path)
        except OSError:
            # Cleanup must never mask the result or error of the real work.
            pass

    @staticmethod
    def _silent_wav(duration_seconds, sample_rate=24000):
        """Return WAV bytes of mono 16-bit silence of the given duration."""
        import wave

        frame_count = int(duration_seconds * sample_rate)
        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(b"\x00\x00" * frame_count)
        # wave does not close a file object it did not open itself, so the
        # buffer is still readable here.
        return buffer.getvalue()

    def preprocess_audio(self, audio_path):
        """Convert audio to Gemma 3n format: 16kHz mono float32 in range [-1, 1]

        Args:
            audio_path: Path of the audio file to load with librosa.

        Returns:
            float32 numpy array, limited to 30 seconds of audio.

        Raises:
            Exception: Any loading error is printed and re-raised.
        """
        try:
            audio, _ = librosa.load(audio_path, sr=self.SAMPLE_RATE, mono=True)
            # Limit to 30 seconds as recommended
            max_samples = self.MAX_AUDIO_SECONDS * self.SAMPLE_RATE
            return self._normalize_and_trim(audio, max_samples)
        except Exception as e:
            print(f"Error preprocessing audio: {e}")
            raise

    def create_multimodal_input(self, audio_path, text_prompt="Respond naturally to this audio input"):
        """Create multimodal input for Gemma 3n using the same format as the notebook

        Args:
            audio_path: Path of the audio file handed to the processor.
            text_prompt: Instruction text placed after the audio in the
                user message.

        Returns:
            The processor output moved to the model's device and dtype.

        Raises:
            Exception: Any preprocessing/processor error is printed (with
                traceback) and re-raised.
        """
        try:
            # Decoding here only validates that the file is readable.
            # NOTE(review): the processor receives the path, not this array,
            # so the 30 s limit and normalisation are not applied to what
            # the model sees - confirm whether that is intended.
            self.preprocess_audio(audio_path)

            message = {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": audio_path},  # Use path instead of array
                    {"type": "text", "text": text_prompt},
                ],
            }

            inputs = self.processor.apply_chat_template(
                [message],  # History is a list
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
            # Follow the model's actual placement: with device_map="auto" it
            # need not be self.device (e.g. on a CPU-only host).
            return inputs.to(self.model.device, dtype=self.model.dtype)
        except Exception as e:
            print(f"Error creating multimodal input: {e}")
            traceback.print_exc()
            raise

    def generate_response(self, audio_path, max_new_tokens=256):
        """Generate text response from audio input using notebook approach

        Args:
            audio_path: Path of the audio file to respond to.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The decoded reply, ``"No response generated"`` if decoding gave
            nothing, or ``"Error: <message>"`` if any step failed (errors
            are reported in-band, never raised).
        """
        try:
            inputs = self.create_multimodal_input(audio_path)
            input_len = inputs["input_ids"].shape[-1]

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    disable_compile=True,
                )

            # Drop the prompt tokens; decode only what was generated.
            text = self.processor.batch_decode(
                outputs[:, input_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            return text[0].strip() if text else "No response generated"
        except Exception as e:
            print(f"Error generating response: {e}")
            traceback.print_exc()
            return f"Error: {str(e)}"

    def stream_response(self, audio_path, max_new_tokens=512, temperature=0.9):
        """Generate a sampled text response from audio input.

        Despite the name, no incremental streaming happens: the complete
        reply is generated and returned as a single string.

        Args:
            audio_path: Path of the audio file to respond to.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature passed to ``generate``.

        Returns:
            The decoded reply, or ``"Error: <message>"`` if any step failed.
        """
        try:
            inputs = self.create_multimodal_input(audio_path)
            tokenizer = self.processor.tokenizer

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                )

            # Decode only the tokens that follow the prompt.
            prompt_len = inputs['input_ids'].shape[1]
            response = tokenizer.decode(
                outputs.sequences[0][prompt_len:],
                skip_special_tokens=True,
            )
            return response.strip()
        except Exception as e:
            print(f"Error in streaming response: {e}")
            traceback.print_exc()
            return f"Error: {str(e)}"

    def text_to_speech_simple(self, text):
        """Convert text to speech using gTTS

        Args:
            text: Text to synthesise.

        Returns:
            WAV bytes (24kHz, mono, 16-bit) on success; WAV bytes of silence
            when gTTS is not installed; ``ERROR_AUDIO`` when synthesis or
            conversion fails.
        """
        try:
            from gtts import gTTS
        except ImportError:
            print("gTTS not available, falling back to silence")
            # Roughly 20 characters per second, at least one second.
            duration_seconds = max(1, len(text) / 20)
            return self._silent_wav(duration_seconds, self.TTS_SAMPLE_RATE)

        mp3_path = None
        try:
            tts = gTTS(text=text, lang='en', slow=False)

            # Reserve a name, then close the handle so gTTS can reopen the
            # path for writing.
            with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
                mp3_path = tmp_file.name
            tts.save(mp3_path)

            # Convert to expected format (24kHz, mono, 16-bit)
            audio_segment = AudioSegment.from_mp3(mp3_path)
            audio_segment = audio_segment.set_frame_rate(self.TTS_SAMPLE_RATE)
            audio_segment = audio_segment.set_channels(1)
            audio_segment = audio_segment.set_sample_width(2)

            audio_buffer = io.BytesIO()
            audio_segment.export(audio_buffer, format="wav")
            return audio_buffer.getvalue()
        except Exception as e:
            print(f"Error in TTS: {e}")
            # Return minimal audio data on error
            return self.ERROR_AUDIO
        finally:
            # Runs on success and failure alike, so the MP3 never leaks.
            self._remove_file(mp3_path)

    def process_audio_stream(self, audio_bytes):
        """Process audio stream and return response audio stream

        Args:
            audio_bytes: Base64-encoded audio file content.

        Returns:
            Audio bytes from ``text_to_speech_simple`` for the generated
            reply, or ``ERROR_AUDIO`` if anything in the pipeline raised.
        """
        try:
            audio_data = base64.b64decode(audio_bytes)

            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                f.write(audio_data)
                temp_audio_path = f.name

            try:
                text_response = self.generate_response(temp_audio_path)
                print(f"Generated response: {text_response}")
                return self.text_to_speech_simple(text_response)
            finally:
                self._remove_file(temp_audio_path)
        except Exception as e:
            print(f"Error processing audio stream: {e}")
            traceback.print_exc()
            # Return minimal audio data on error
            return self.ERROR_AUDIO

    def warm_up(self):
        """Warm up the model

        Runs one short generation on a second of silence. Failures are
        printed and swallowed so that start-up never aborts here.
        """
        dummy_path = None
        try:
            print("Warming up Gemma 3n model...")
            import soundfile as sf

            dummy_audio = np.zeros(self.SAMPLE_RATE, dtype=np.float32)  # 1 second of silence

            # Close the handle before soundfile writes to the same path.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                dummy_path = f.name
            sf.write(dummy_path, dummy_audio, self.SAMPLE_RATE)

            response = self.generate_response(dummy_path, max_new_tokens=10)
            print(f"Warm-up response: {response}")
            print("Gemma 3n warm-up complete")
        except Exception as e:
            # Don't fail startup on warm-up errors
            print(f"Error during warm-up: {e}")
        finally:
            self._remove_file(dummy_path)