"""
FastRTC + Gemma-3 minimal voice chat app
Requirements:
pip install fastrtc transformers torch torchaudio
"""
import numpy as np
from fastrtc import (
    ReplyOnPause,
    Stream,
    get_stt_model,
    get_tts_model,
)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# ------------------------------------------------------------------
# 1. Load Gemma-3 (4b-it) via transformers
# ------------------------------------------------------------------
MODEL_ID = "google/gemma-3-4b-it"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
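# Note: the Gemma weights are gated on the Hugging Face Hub, so you may need to
# accept the model license and authenticate (e.g. `huggingface-cli login`) before
# the first download. device_map="auto" also requires the `accelerate` package.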
# ------------------------------------------------------------------
# 2. Build a simple chat pipeline
# ------------------------------------------------------------------
chat_pipeline = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=256,
do_sample=True,
temperature=0.7,
)
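# Note: with a plain string prompt, the text-generation pipeline includes the
# prompt at the start of "generated_text"; response_generator() below slices it
# off. Passing return_full_text=False to the pipeline would be an alternative.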
# ------------------------------------------------------------------
# 3. Voice pipeline helpers
# ------------------------------------------------------------------
stt = get_stt_model()  # default Moonshine STT; "moonshine/tiny" is a lighter option
tts = get_tts_model()  # default Kokoro TTS
# ------------------------------------------------------------------
# 4. Response generator
# ------------------------------------------------------------------
def response_generator(prompt: str) -> str:
"""Feed the user prompt to Gemma-3 and return the assistant text."""
messages = [{"role": "user", "content": prompt}]
prompt_text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
output = chat_pipeline(prompt_text)[0]["generated_text"]
# strip the prompt from the output
return output[len(prompt_text) :].strip()
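# Illustrative usage (not executed at import time):
#   response_generator("What is WebRTC in one sentence?")
#   -> a short plain-text assistant reply with the prompt removed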
# ------------------------------------------------------------------
# 5. FastRTC streaming handler
# ------------------------------------------------------------------
def chat_handler(audio: tuple[int, np.ndarray]):
    """Receive one user utterance, transcribe it, answer via Gemma-3, stream TTS audio back."""
    user_text = stt.stt(audio)
    if not user_text.strip():
        return
    # Generate response
    reply_text = response_generator(user_text)
    # Stream TTS audio back to the user as (sample_rate, np.ndarray) chunks
    for chunk in tts.stream_tts_sync(reply_text):
        yield chunk
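# ReplyOnPause runs voice-activity detection on the incoming audio and invokes
# chat_handler with one complete utterance each time the caller stops speaking,
# so the handler never has to deal with raw packet-level chunks itself.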
# ------------------------------------------------------------------
# 6. Launch the app
# ------------------------------------------------------------------
if __name__ == "__main__":
    stream = Stream(
        handler=ReplyOnPause(chat_handler),
        modality="audio",
        mode="send-receive",
    )
    stream.ui.launch(server_name="0.0.0.0", server_port=7860)
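# The Gradio UI launched above is the quickest way to test locally. For a custom
# deployment, a FastRTC Stream can instead be mounted on a FastAPI app
# (stream.mount(app)); see the FastRTC docs for details.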