File size: 3,341 Bytes
22d4d93
ae054df
 
22d4d93
 
ae054df
22d4d93
 
ae054df
 
22d4d93
 
 
 
 
 
 
ae054df
872589f
ae054df
 
 
 
 
 
 
6321d71
 
 
 
ae054df
 
 
 
 
 
 
 
 
 
 
 
 
 
22d4d93
 
 
 
 
 
 
 
872589f
22d4d93
 
 
 
6c44434
22d4d93
6c44434
ae054df
 
 
6321d71
 
ae054df
 
6c44434
22d4d93
 
 
 
 
 
 
 
6321d71
22d4d93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae054df
 
22d4d93
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
import numpy as np
import torch
from dotenv import load_dotenv
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    WebRTCError,
    audio_to_float32,
    get_current_context,
    get_hf_turn_credentials,
    get_hf_turn_credentials_async,
    get_stt_model,
    get_tts_model,
)
from huggingface_hub import InferenceClient
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import spaces

# ---------------------------------------------------------------------------
# Speech-to-text: Whisper served through the transformers ASR pipeline.
# ---------------------------------------------------------------------------
# Use the first CUDA device in half precision when one is available; otherwise
# stay on the CPU in full precision.
if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

model_id = "openai/whisper-large-v3-turbo"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    use_safetensors=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch_dtype,
)
model.to(device)

# The processor bundles the tokenizer and the feature extractor that the
# pipeline below is given.
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    feature_extractor=processor.feature_extractor,
    tokenizer=processor.tokenizer,
    device=device,
    torch_dtype=torch_dtype,
)


# Load variables from a local .env file (if any) before the FastRTC models are
# created.
load_dotenv()

# NOTE(review): `stt_model` is not referenced anywhere else in this file;
# transcription in `response` goes through `pipe`. Kept so the module's
# public names and import-time behaviour stay unchanged.
stt_model = get_stt_model()
tts_model = get_tts_model()

# Chat history per WebRTC connection, keyed by `webrtc_id`.
conversations: dict[str, list[dict[str, str]]] = {}

def _collect_stream_text(chunks) -> str:
    """Join the text deltas of a streamed chat completion into one string.

    Chunks whose ``choices`` list is empty are skipped instead of raising
    ``IndexError``; OpenAI-compatible streams may emit such a chunk (for
    example a trailing usage-only chunk). A ``None`` delta content counts as
    the empty string.
    """
    parts = []
    for chunk in chunks:
        if not chunk.choices:
            continue
        parts.append(chunk.choices[0].delta.content or "")
    return "".join(parts)


@spaces.GPU
def response(
    audio: tuple[int, np.ndarray],
    hf_token: str | None,
):
    """Transcribe one user utterance, ask the LLM for a reply and speak it.

    Args:
        audio: ``(sampling_rate, samples)`` for the utterance to answer.
        hf_token: Hugging Face token used to create the ``InferenceClient``.

    Yields:
        Everything produced by ``tts_model.stream_tts_sync`` for the reply
        text, followed by one ``AdditionalOutputs`` wrapping the updated
        message history of the current connection.

    Raises:
        WebRTCError: If ``hf_token`` is ``None`` or empty.
    """
    if not hf_token:
        raise WebRTCError("HF Token is required")

    llm_client = InferenceClient(provider="auto", token=hf_token)

    sample_rate, samples = audio
    result = pipe(
        {"array": audio_to_float32(samples).squeeze(), "sampling_rate": sample_rate},
        generate_kwargs={"language": "en"},
    )
    transcription = result["text"]

    context = get_current_context()
    # First utterance on this connection: seed the history with the system
    # prompt. Each sentence literal ends with a space so the implicitly
    # concatenated prompt does not read "conversations.Your responses ...".
    history = conversations.setdefault(
        context.webrtc_id,
        [
            {
                "role": "system",
                "content": (
                    "You are a helpful assistant that can have engaging conversations. "
                    "Your responses must be very short and concise. No more than two sentences. "
                    "Reasoning: low"
                ),
            }
        ],
    )

    user_message = {"role": "user", "content": transcription}

    output = llm_client.chat.completions.create(  # type: ignore
        model="openai/gpt-oss-20b",
        messages=[*history, user_message],  # type: ignore
        max_tokens=1024,
        stream=True,
    )
    output_text = _collect_stream_text(output)

    # Commit the turn only once the reply has been received in full, so a
    # failed request does not leave an unanswered user message in the shared
    # history. `history` is the list stored in `conversations`, so appending
    # to it updates the stored conversation in place.
    history.append(user_message)
    history.append({"role": "assistant", "content": output_text})

    yield from tts_model.stream_tts_sync(output_text)
    yield AdditionalOutputs(history)


# Chat panel shown next to the audio widget. It is registered below as the
# stream's additional output, so it displays the message list that `response`
# emits through `AdditionalOutputs`.
chatbot = gr.Chatbot(label="Chatbot", type="messages")
# Password-style text box for the user's Hugging Face token. It is registered
# below as the stream's additional input (presumably delivered to `response`
# as `hf_token` — confirm against the FastRTC docs).
token = gr.Textbox(
    label="HF Token",
    value="",
    type="password",
)
# Bidirectional audio stream: `response` is wrapped in `ReplyOnPause` and used
# as the handler.
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response),
    # Called once here, at import time; the result is passed as a value.
    server_rtc_configuration=get_hf_turn_credentials(),
    # Passed as a callable (not called here).
    rtc_configuration=get_hf_turn_credentials_async,
    additional_inputs=[token],
    additional_outputs=[chatbot],
    # Replace the previous chatbot value with the newly emitted one.
    additional_outputs_handler=lambda old, new: new,
    ui_args={"title": "Talk To OpenAI GPT-OSS 20B (Powered by FastRTC ⚡️)"},
    # NOTE(review): presumably a per-connection limit in seconds and a cap on
    # simultaneous connections — confirm against the FastRTC `Stream` docs.
    time_limit=90,
    concurrency_limit=5,
)

# Launch the stream's built-in UI. This runs at import time; there is no
# `if __name__ == "__main__":` guard.
stream.ui.launch()