freddyaboulton's picture
Update app.py
872589f verified
raw
history blame
3.34 kB
import gradio as gr
import numpy as np
import torch
from dotenv import load_dotenv
from fastrtc import (
AdditionalOutputs,
ReplyOnPause,
Stream,
WebRTCError,
audio_to_float32,
get_current_context,
get_hf_turn_credentials,
get_hf_turn_credentials_async,
get_stt_model,
get_tts_model,
)
from huggingface_hub import InferenceClient
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import spaces
# Load variables from a local .env file (if any) before touching the Hub, so
# that any Hugging Face settings it defines (e.g. HF_TOKEN, HF_HOME) are already
# in the environment when the Whisper weights below are downloaded.
load_dotenv()

# Run Whisper on the first GPU in half precision when CUDA is available,
# otherwise fall back to CPU with full precision.
_cuda_available = torch.cuda.is_available()
device = "cuda:0" if _cuda_available else "cpu"
torch_dtype = torch.float16 if _cuda_available else torch.float32

model_id = "openai/whisper-large-v3-turbo"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

# Speech-to-text pipeline used by `response` to transcribe each utterance.
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

# NOTE(review): `stt_model` is not referenced by `response`, which transcribes
# with `pipe`; kept in case other code relies on it, but consider removing it
# to save startup time and memory.
stt_model = get_stt_model()
# Text-to-speech model that `response` streams the LLM reply through.
tts_model = get_tts_model()

# Chat history per WebRTC connection id. Nothing visible here removes entries,
# so the dict grows with every new connection for the life of the process.
conversations: dict[str, list[dict[str, str]]] = {}
# System prompt seeded into every new conversation. "Reasoning: low" is
# presumably the gpt-oss reasoning-effort directive — confirm against the
# model card.
_SYSTEM_PROMPT = (
    "You are a helpful assistant that can have engaging conversations. "
    "Your responses must be very short and concise. No more than two sentences. "
    "Reasoning: low"
)


@spaces.GPU
def response(
    audio: tuple[int, np.ndarray],
    hf_token: str | None,
):
    """Transcribe one user utterance, ask the LLM for a reply, and speak it back.

    Used as the ``ReplyOnPause`` handler, so it runs once per detected pause.

    Args:
        audio: ``(sample_rate, samples)`` tuple for the captured utterance.
        hf_token: Hugging Face token from the "HF Token" textbox, used to
            authenticate the inference client.

    Yields:
        The audio chunks produced by ``tts_model.stream_tts_sync`` for the
        reply, followed by an ``AdditionalOutputs`` carrying the updated
        chat history for the chatbot component. Yields nothing when the
        transcription is empty.

    Raises:
        WebRTCError: if ``hf_token`` is missing or empty.
    """
    # Fail fast before spending GPU time on transcription.
    if not hf_token:
        raise WebRTCError("HF Token is required")

    sample_rate, samples = audio
    result = pipe(
        {"array": audio_to_float32(samples).squeeze(), "sampling_rate": sample_rate},
        generate_kwargs={"language": "en"},
    )
    transcription = result["text"].strip()
    # Nothing was transcribed (e.g. silence or noise): skip the LLM round-trip
    # rather than sending an empty user turn.
    if not transcription:
        return

    webrtc_id = get_current_context().webrtc_id
    if webrtc_id not in conversations:
        conversations[webrtc_id] = [{"role": "system", "content": _SYSTEM_PROMPT}]
    messages = conversations[webrtc_id]

    user_message = {"role": "user", "content": transcription}
    llm_client = InferenceClient(provider="auto", token=hf_token)
    # Send the new turn without committing it to history yet, so a failed
    # request does not leave a dangling user message in the conversation.
    output = llm_client.chat.completions.create(  # type: ignore
        model="openai/gpt-oss-20b",
        messages=[*messages, user_message],  # type: ignore
        max_tokens=1024,
        stream=True,
    )
    parts: list[str] = []
    for chunk in output:
        # Guard against stream chunks whose `choices` list is empty
        # (e.g. a trailing usage-only chunk).
        if chunk.choices:
            parts.append(chunk.choices[0].delta.content or "")
    output_text = "".join(parts)

    # Commit both sides of the exchange only once the reply is complete.
    messages.append(user_message)
    messages.append({"role": "assistant", "content": output_text})

    yield from tts_model.stream_tts_sync(output_text)
    yield AdditionalOutputs(messages)
# Chat transcript component; updated from the AdditionalOutputs that
# `response` yields after each reply.
chatbot = gr.Chatbot(label="Chatbot", type="messages")

# Password box whose value reaches `response` as its extra `hf_token` input.
token = gr.Textbox(label="HF Token", value="", type="password")


def _replace_history(old, new):
    """Display the newest chat history, discarding the previous value."""
    return new


stream = Stream(
    handler=ReplyOnPause(response),
    modality="audio",
    mode="send-receive",
    additional_inputs=[token],
    additional_outputs=[chatbot],
    additional_outputs_handler=_replace_history,
    # Server-side TURN credentials are fetched once at startup; the client
    # side receives a callable instead.
    server_rtc_configuration=get_hf_turn_credentials(),
    rtc_configuration=get_hf_turn_credentials_async,
    ui_args={"title": "Talk To OpenAI GPT-OSS 20B (Powered by FastRTC ⚡️)"},
    time_limit=90,
    concurrency_limit=5,
)

stream.ui.launch()