File size: 2,200 Bytes
606dee0
 
 
6ae9e35
606dee0
6ae9e35
 
606dee0
 
 
 
 
 
6ae9e35
606dee0
 
 
 
 
 
 
 
 
 
 
6ae9e35
606dee0
 
 
 
 
 
 
 
 
6ae9e35
606dee0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ae9e35
 
 
606dee0
6ae9e35
 
606dee0
6ae9e35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606dee0
6ae9e35
606dee0
6ae9e35
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import time

import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from fastrtc import (
    ReplyOnPause,
    Stream,
    get_stt_model,
    get_tts_model,
)
from gradio.utils import get_space
from numpy.typing import NDArray
from openai import OpenAI

# Load variables from a .env file (if any) before the API key is read below.
load_dotenv()

# OpenAI-compatible client pointed at the SambaNova endpoint; the key is taken
# from the SAMBANOVA_API_KEY environment variable.
sambanova_client = OpenAI(
    base_url="https://api.sambanova.ai/v1",
    api_key=os.environ.get("SAMBANOVA_API_KEY"),
)

# Speech-to-text and text-to-speech models, created with fastrtc's defaults.
stt_model = get_stt_model()
tts_model = get_tts_model()

# Running conversation sent to the chat-completions endpoint on every turn.
# It starts with the system prompt only; `echo` appends user/assistant turns.
chat_history = [
    {
        "role": "system",
        "content": (
            # Adjacent literals are concatenated verbatim, so the separating
            # space has to be part of the first literal.
            "You are a helpful assistant having a spoken conversation. "
            "Please keep your answers short and concise."
        ),
    }
]


def echo(audio: tuple[int, NDArray[np.int16]]):
    """Answer one spoken user turn with synthesized speech.

    The audio is transcribed with ``stt_model``, the transcript is appended to
    the module-level ``chat_history`` and sent to the SambaNova
    chat-completions endpoint, and the reply is streamed back chunk by chunk
    from ``tts_model``.

    Args:
        audio: Tuple handed over by the stream handler and passed unchanged to
            ``stt_model.stt`` (presumably ``(sample_rate, samples)`` -- confirm
            against the fastrtc handler contract).

    Yields:
        Audio chunks produced by ``tts_model.stream_tts_sync`` for the reply.
        Nothing is yielded when the transcript or the model reply is empty.
    """
    # NOTE(review): chat_history is module-level, so every caller of this
    # handler appends to the same conversation -- confirm this is intended
    # when several clients are connected at once.
    prompt = stt_model.stt(audio)
    print("prompt", prompt)
    if not prompt or not prompt.strip():
        # Nothing usable was transcribed: skip the LLM round trip instead of
        # recording and sending an empty user turn.
        return

    chat_history.append({"role": "user", "content": prompt})
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments.
    started = time.perf_counter()
    response = sambanova_client.chat.completions.create(
        model="Meta-Llama-3.2-3B-Instruct",
        messages=chat_history,
        max_tokens=200,
    )
    print("time taken inference", time.perf_counter() - started)

    reply = response.choices[0].message.content
    if not reply:
        # The client types `content` as optional; there is nothing to speak
        # and no assistant turn worth storing.
        return
    chat_history.append({"role": "assistant", "content": reply})

    started = time.perf_counter()
    for audio_chunk in tts_model.stream_tts_sync(reply):
        yield audio_chunk
    print("time taken tts", time.perf_counter() - started)


# Audio send/receive stream whose replies come from `echo`, triggered through
# ReplyOnPause.
# NOTE: no custom RTC configuration is passed; an earlier variant used
# `get_twilio_turn_credentials() if get_space() else None` here.
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(echo),
    # A limit of 20 applies when get_space() is truthy; otherwise None is passed.
    concurrency_limit=20 if get_space() else None,
    rtc_configuration=None,
)

# FastAPI application that exposes the stream's routes.
app = FastAPI()

stream.mount(app)


@app.get("/")
async def index():
    """Redirect the root path to the UI.

    When ``get_space()`` is truthy the redirect goes to the absolute hosted
    URL; otherwise it goes to the relative ``/ui`` path.
    """
    if get_space():
        # NOTE(review): this URL names the "fastrtc-echo-audio" space --
        # confirm it is the right target for this app.
        target = "https://fastrtc-echo-audio.hf.space/ui/"
    else:
        target = "/ui"
    return RedirectResponse(url=target)


if __name__ == "__main__":
    # The MODE environment variable picks the front end:
    #   "UI"    -> launch the stream's built-in UI
    #   "PHONE" -> start the stream's fastphone entry point
    #   other   -> serve the FastAPI app with uvicorn
    # `os` is already imported at the top of the file, so it is not
    # re-imported here.
    port = 7860  # single port shared by all three launch modes

    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=port)
    elif mode == "PHONE":
        stream.fastphone(port=port)
    else:
        # Imported lazily: uvicorn is only needed for the plain server mode.
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=port)