"""
FastRTC + Gemma-3 minimal voice chat app
Requirements:
    pip install fastrtc transformers torch torchaudio
"""

import numpy as np

from fastrtc import (
    ReplyOnPause,
    Stream,
    get_stt_model,
    get_tts_model,
)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# ------------------------------------------------------------------
# 1.  Load Gemma-3 via transformers
# ------------------------------------------------------------------
# gemma-3-1b-it is the text-only instruction-tuned checkpoint; the larger
# Gemma-3 checkpoints are multimodal and load as
# Gemma3ForConditionalGeneration rather than through AutoModelForCausalLM.
# Gemma-3 is supported natively in transformers, so trust_remote_code is
# not needed.
MODEL_ID = "google/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # requires `accelerate`
)

# ------------------------------------------------------------------
# 2.  Build a simple chat pipeline
# ------------------------------------------------------------------
chat_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
)

# ------------------------------------------------------------------
# 3.  Voice pipeline helpers
# ------------------------------------------------------------------
# fastrtc ships Moonshine for speech-to-text ("moonshine/tiny" or
# "moonshine/base") and Kokoro for text-to-speech; external engines
# such as Coqui XTTS are not bundled, so these are the supported names.
stt = get_stt_model("moonshine/tiny")
tts = get_tts_model("kokoro")
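
# Both helpers follow fastrtc's simple model protocols; a sketch of the
# calls the handler below relies on:
#   text = stt.stt((sample_rate, samples))    # one-shot transcription
#   for chunk in tts.stream_tts_sync(text):   # streamed synthesis
#       ...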


# ------------------------------------------------------------------
# 4.  Response generator
# ------------------------------------------------------------------
def response_generator(prompt: str) -> str:
    """Feed the user prompt to Gemma-3 and return the assistant text."""
    messages = [{"role": "user", "content": prompt}]
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # return_full_text=False makes the pipeline return only the newly
    # generated text, so we don't have to slice the prompt off by hand
    # (string slicing can break when detokenization alters the prompt).
    output = chat_pipeline(prompt_text, return_full_text=False)
    return output[0]["generated_text"].strip()
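

# Optional sanity check: exercise the text pipeline on its own, without
# the audio stack (the prompt here is only an illustrative example):
#   print(response_generator("In one sentence, what is WebRTC?"))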


# ------------------------------------------------------------------
# 5.  FastRTC streaming handler
# ------------------------------------------------------------------
def chat_handler(audio: tuple[int, np.ndarray]):
    """Receive one user utterance, transcribe it, answer via Gemma-3,
    and stream TTS audio back to the user.

    ReplyOnPause invokes this once per detected pause, passing the whole
    utterance as a (sample_rate, samples) tuple; the handler is a plain
    generator that yields audio chunks in the same format.
    """
    user_text = stt.stt(audio)
    if not user_text.strip():
        return

    # Generate the text response
    reply_text = response_generator(user_text)

    # Stream TTS audio back to the user
    for chunk in tts.stream_tts_sync(reply_text):
        yield chunk


# ------------------------------------------------------------------
# 6.  Launch the app
# ------------------------------------------------------------------
if __name__ == "__main__":
    stream = Stream(
        handler=ReplyOnPause(chat_handler),
        modality="audio",
        mode="send-receive",
    )
    stream.ui.launch(server_name="0.0.0.0", server_port=7860)
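
# To try it (assuming this file is saved as, say, app.py):
#   python app.py
# then open http://localhost:7860 in a browser and grant microphone access;
# speak, pause, and ReplyOnPause will trigger a spoken reply.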