from fastrtc import (
    ReplyOnPause, AdditionalOutputs, Stream,
    audio_to_bytes, aggregate_bytes_to_16bit
)
import gradio as gr
import numpy as np
import torch
import os
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan
)
from datasets import load_dataset
import scipy

# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Step 1: Audio transcription with Whisper
def load_asr_model():
    model_id = "openai/whisper-small"  # Smaller version that's more efficient
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=16,
        return_timestamps=False,
        torch_dtype=torch_dtype,
        device=device,
    )

# Step 2: Text generation with a smaller LLM
def load_llm_model():
    model_id = "facebook/opt-1.3b"  # A smaller language model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True
    )
    model.to(device)
    return model, tokenizer

# Step 3: Text-to-Speech with a free model
def load_tts_model():
    model_id = "microsoft/speecht5_tts"
    processor = AutoProcessor.from_pretrained(model_id)
    model = SpeechT5ForTextToSpeech.from_pretrained(model_id)
    model.to(device)
    # Load vocoder for waveform generation
    vocoder_id = "microsoft/speecht5_hifigan"
    vocoder = SpeechT5HifiGan.from_pretrained(vocoder_id)
    vocoder.to(device)
    # Load speaker embeddings
    embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    speaker_embeddings = torch.tensor(embeddings_dataset[7]["xvector"]).unsqueeze(0)
    return model, processor, vocoder, speaker_embeddings

# Initialize all models
print("Loading ASR model...")
asr_pipeline = load_asr_model()
print("Loading LLM model...")
llm_model, llm_tokenizer = load_llm_model()
print("Loading TTS model...")
tts_model, tts_processor, tts_vocoder, speaker_embeddings = load_tts_model()

# Chat history management
chat_history = []

def generate_response(prompt):
    # If chat history is empty, add a system message
    if not chat_history:
        chat_history.append({"role": "system", "content": "You are a helpful, friendly AI assistant. Keep your responses concise and conversational."})
    # Add user message to history
    chat_history.append({"role": "user", "content": prompt})
    # Flatten the history into a plain-text prompt for the (non-chat) OPT model
    full_prompt = ""
    for message in chat_history:
        if message["role"] == "system":
            full_prompt += f"System: {message['content']}\n"
        elif message["role"] == "user":
            full_prompt += f"User: {message['content']}\n"
        elif message["role"] == "assistant":
            full_prompt += f"Assistant: {message['content']}\n"
    full_prompt += "Assistant: "
    # Generate response
    inputs = llm_tokenizer(full_prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = llm_model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )
    response_text = llm_tokenizer.decode(output[0], skip_special_tokens=True)
    # The decoded text contains the whole prompt, so keep only what follows the last "Assistant: "
    response_text = response_text.split("Assistant: ")[-1].strip()
    # Add assistant response to history
    chat_history.append({"role": "assistant", "content": response_text})
    # Keep history at a reasonable size: the system message plus the last 9 messages
    if len(chat_history) > 10:
        chat_history.pop(1)
    return response_text

def text_to_speech(text):
    # Prepare inputs
    inputs = tts_processor(text=text, return_tensors="pt")
    # Generate a mel spectrogram conditioned on the speaker embedding
    with torch.no_grad():
        speech = tts_model.generate_speech(
            inputs["input_ids"].to(device),
            speaker_embeddings.to(device)
        )
    # Convert the spectrogram to a waveform using the vocoder
    with torch.no_grad():
        waveform = tts_vocoder(speech)
    # Convert to numpy array
    audio_array = waveform.cpu().numpy().squeeze()
    # Normalize and convert to int16
    audio_array = (audio_array / np.max(np.abs(audio_array)) * 32767).astype(np.int16)
    # Reshape for fastrtc
    audio_array = audio_array.reshape(1, -1)
    return (16000, audio_array)  # SpeechT5's HiFi-GAN vocoder produces 16 kHz audio

def response(audio: tuple[int, np.ndarray]):
    # Step 1: Speech-to-Text (Whisper expects float32 audio in [-1, 1], so rescale int16 PCM)
    audio_float = audio[1].flatten().astype(np.float32)
    if np.abs(audio_float).max() > 1.0:
        audio_float = audio_float / 32768.0
    transcript = asr_pipeline({"sampling_rate": audio[0], "raw": audio_float})
    prompt = transcript["text"]
    # Step 2: Generate text response
    response_text = generate_response(prompt)
    # Step 3: Text-to-Speech
    sample_rate, audio_array = text_to_speech(response_text)
    # Stream the reply back in small chunks
    chunk_size = 3200  # 200ms chunks at 16kHz
    for i in range(0, audio_array.shape[1], chunk_size):
        chunk = audio_array[:, i:i+chunk_size]
        if chunk.size > 0:  # Ensure we don't yield empty chunks
            yield (sample_rate, chunk)

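# Wire the ASR -> LLM -> TTS pipeline into a FastRTC stream. ReplyOnPause buffers
# the caller's audio and invokes `response` once it detects a pause in speech,
# then plays back each yielded chunk as the reply.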
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response),
)

# For testing without WebRTC
def demo():
    with gr.Blocks() as demo:
        gr.Markdown("# Local Voice Chatbot")
        audio_input = gr.Audio(sources=["microphone"], type="numpy")
        audio_output = gr.Audio()

        def process_audio(audio):
            if audio is None:
                return None
            sample_rate, audio_array = audio
            # Gradio's microphone input is int16 PCM; rescale to float32 in [-1, 1] for Whisper
            audio_float = audio_array.flatten().astype(np.float32)
            if np.abs(audio_float).max() > 1.0:
                audio_float = audio_float / 32768.0
            transcript = asr_pipeline({"sampling_rate": sample_rate, "raw": audio_float})
            prompt = transcript["text"]
            print(f"Transcribed: {prompt}")
            response_text = generate_response(prompt)
            print(f"Response: {response_text}")
            sample_rate, audio_array = text_to_speech(response_text)
            return (sample_rate, audio_array[0])

        audio_input.change(process_audio, inputs=[audio_input], outputs=[audio_output])
    demo.launch()

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--demo", action="store_true", help="Run Gradio demo instead of WebRTC")
    args = parser.parse_args()
    if args.demo:
        demo()
    else:
        # For running with FastRTC
        # You would need to add your FastRTC server code here
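        # One option (an assumption, not part of the original script): fastrtc's
        # Stream exposes a built-in Gradio UI, which could be launched with e.g.
        #   stream.ui.launch()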
        pass
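
# Usage (assuming this script is saved as app.py):
#   python app.py --demo    # test locally with the Gradio microphone demo
#   python app.py           # run the FastRTC WebRTC stream (see the else branch above)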