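"""FastRTC voice-chat demo.

Transcribes the caller's speech, asks a SambaNova-hosted Llama model for a
short reply, and streams the synthesized answer back over WebRTC. The Stream
is mounted on a FastAPI app; set MODE=UI to launch the Gradio UI or
MODE=PHONE to serve the stream over a phone call (otherwise run under
uvicorn).
"""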
import os
import time

import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from fastrtc import (
    ReplyOnPause,
    Stream,
    get_stt_model,
    get_tts_model,
)
from gradio.utils import get_space
from numpy.typing import NDArray
from openai import OpenAI

load_dotenv()

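# SambaNova serves an OpenAI-compatible API, so the stock OpenAI client works
# once base_url points at the SambaNova endpoint.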
sambanova_client = OpenAI(
    api_key=os.getenv("SAMBANOVA_API_KEY"), base_url="https://api.sambanova.ai/v1"
)

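# Lightweight local speech models bundled with fastrtc
# (Moonshine STT and Kokoro TTS by default).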
stt_model = get_stt_model()
tts_model = get_tts_model()

chat_history = [
    {
        "role": "system",
        "content": (
            "You are a helpful assistant having a spoken conversation. "
            "Please keep your answers short and concise."
        ),
    }
]


def echo(audio: tuple[int, NDArray[np.int16]]):
    # Transcribe the incoming audio (sample rate, samples) to text.
    prompt = stt_model.stt(audio)
    print("prompt", prompt)
    chat_history.append({"role": "user", "content": prompt})

    # Generate a short reply from the full conversation history.
    start_time = time.time()
    response = sambanova_client.chat.completions.create(
        model="Meta-Llama-3.2-3B-Instruct",
        messages=chat_history,
        max_tokens=200,
    )
    end_time = time.time()
    print("time taken inference", end_time - start_time)

    response_text = response.choices[0].message.content
    chat_history.append({"role": "assistant", "content": response_text})

    # Synthesize the reply and stream audio chunks back to the caller.
    start_time = time.time()
    for audio_chunk in tts_model.stream_tts_sync(response_text):
        yield audio_chunk
    end_time = time.time()
    print("time taken tts", end_time - start_time)


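# ReplyOnPause runs the handler each time voice-activity detection decides
# the caller has stopped speaking.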
stream = Stream(
    handler=ReplyOnPause(echo),
    modality="audio",
    mode="send-receive",
    rtc_configuration=None,  # get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=20 if get_space() else None,
)

app = FastAPI()

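# Register the stream's WebRTC signalling/streaming routes on the FastAPI app.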
stream.mount(app)


@app.get("/")
async def index():
    return RedirectResponse(
        url="/ui" if not get_space() else "https://fastrtc-echo-audio.hf.space/ui/"
    )


if __name__ == "__main__":
    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)