Create app.py
app.py
ADDED
@@ -0,0 +1,212 @@
from __future__ import annotations

import os
import time
import uuid

import gradio as gr
import nltk
import numpy as np
import torch
from transformers import pipeline

# Agree to the Coqui TOS non-interactively and fetch the sentence tokenizer
# used later to split answers into sentences.
os.environ["COQUI_TOS_AGREED"] = "1"
nltk.download('punkt')

from TTS.api import TTS
from huggingface_hub import HfApi

HF_TOKEN = os.environ.get("HF_TOKEN")


tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1", gpu=True)
title = "Voice Chat Mistral"
DESCRIPTION = title
css = """.toast-wrap { display: none !important } """
api = HfApi(token=HF_TOKEN)
repo_id = "ylacombe/voice-chat-with-lama"
system_message = "\nYou are a helpful assistant."
temperature = 0.9
top_p = 0.6
repetition_penalty = 1.2

from gradio_client import Client
from huggingface_hub import InferenceClient


whisper_client = Client("https://sanchit-gandhi-whisper-large-v2.hf.space/")
text_client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1"
)
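
# Both clients call remote endpoints (a public Whisper Space and the Hugging
# Face Inference API), so only the XTTS model above runs inside this Space.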


def format_prompt(message, history):
    # Wrap each past turn in the Mistral-Instruct template,
    # [INST] user [/INST] answer</s>, then append the new user message.
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
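
# For example, with a hypothetical one-turn history:
#   format_prompt("How are you?", [("Hi", "Hello!")])
#   -> '<s>[INST] Hi [/INST] Hello!</s> [INST] How are you? [/INST]'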

def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(prompt, history)

    # Stream tokens from the Inference API, yielding the accumulated text.
    stream = text_client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        output += response.token.text
        yield output
    return output
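
# Usage sketch (hypothetical call): generate() is a generator, so callers see
# the answer grow incrementally:
#   for partial in generate("Tell me a joke", history=[]):
#       ...  # `partial` holds everything generated so far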


def transcribe(wav_path):
    # Delegate speech-to-text to the remote Whisper Space.
    return whisper_client.predict(
        wav_path,      # str (filepath or URL to file) in 'inputs' Audio component
        "transcribe",  # str in 'Task' Radio component
        api_name="/predict"
    )
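
# Minimal usage sketch (assuming a local recording "question.wav" exists):
#   text = transcribe("question.wav")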


# Chatbot demo with multimodal input (typed text or recorded speech) and
# streamed text and audio responses.


def add_text(history, text):
    history = [] if history is None else history
    history = history + [(text, None)]
    # Clear the textbox and lock it until the response finishes streaming.
    return history, gr.update(value="", interactive=False)


def add_file(history, file):
    history = [] if history is None else history
    # Transcribe the recording and treat the text as the user's turn.
    text = transcribe(file)

    history = history + [(text, None)]
    return history


def bot(history, system_prompt=""):
    history = [] if history is None else history

    # Note: system_prompt is accepted here, but format_prompt() never injects
    # it, so the default system_message is effectively unused.
    if system_prompt == "":
        system_prompt = system_message

    history[-1][1] = ""
    for character in generate(history[-1][0], history[:-1]):
        history[-1][1] = character
        yield history
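
# Because bot() yields the updated history on every token, the Chatbot
# component re-renders the growing answer as it streams in.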


def generate_speech(history):
    text_to_generate = history[-1][1]
    text_to_generate = text_to_generate.replace("\n", " ").strip()
    # Split the answer into sentences so each one can be synthesized and
    # streamed to the player as soon as it is ready.
    text_to_generate = nltk.sent_tokenize(text_to_generate)

    filename = f"{uuid.uuid4()}.wav"  # prepared but unused in the streaming path
    sampling_rate = tts.synthesizer.tts_config.audio["sample_rate"]
    silence = [0] * int(0.25 * sampling_rate)

    for sentence in text_to_generate:
        try:
            # Generate speech by cloning a voice using default settings.
            wav = tts.tts(text=sentence,
                          speaker_wav="examples/female.wav",
                          decoder_iterations=25,
                          decoder_sampler="dpm++2m",
                          speed=1.2,
                          language="en")

            yield (sampling_rate, np.array(wav))  # np.array(wav + silence))

        except RuntimeError as e:
            if "device-side assert" in str(e):
                # Cannot recover from a CUDA device-side assert; need to restart.
                print(f"Exit due to: Unrecoverable exception caused by prompt: {sentence}", flush=True)
                gr.Warning("Unhandled exception encountered, please retry in a minute")
                print("CUDA device-side assert runtime error encountered, restart needed")

                # HF Space specific: this error is unrecoverable, restart the Space.
                api.restart_space(repo_id=repo_id)
            else:
                print("RuntimeError: non device-side assert error:", str(e))
                raise e

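# Restoring the commented-out padding variant inside generate_speech (a sketch,
# not the shipped behavior) would separate sentences with 0.25 s of silence:
#   yield (sampling_rate, np.array(wav + silence))
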
with gr.Blocks(title=title) as demo:
    gr.Markdown(DESCRIPTION)

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=('examples/lama.jpeg', 'examples/lama2.jpeg'),
        bubble_full_width=False,
    )

    with gr.Row():
        txt = gr.Textbox(
            scale=3,
            show_label=False,
            placeholder="Enter text and press enter, or speak to your microphone",
            container=False,
        )
        txt_btn = gr.Button(value="Submit text", scale=1)
        btn = gr.Audio(source="microphone", type="filepath", scale=4)

    with gr.Row():
        audio = gr.Audio(type="numpy", streaming=True, autoplay=True, label="Generated audio response", show_label=True)

    clear_btn = gr.ClearButton([chatbot, audio])

    # All three input paths share the same chain: append the user turn,
    # stream the text reply, then stream the spoken reply.
    txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, chatbot, chatbot
    ).then(generate_speech, chatbot, audio)

    txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)

    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, chatbot, chatbot
    ).then(generate_speech, chatbot, audio)

    txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)

    file_msg = btn.stop_recording(add_file, [chatbot, btn], [chatbot], queue=False).then(
        bot, chatbot, chatbot
    ).then(generate_speech, chatbot, audio)

    gr.Markdown("""
This Space demonstrates how to speak to a chatbot, based solely on open-source models.
It relies on 3 models:
1. [Whisper-large-v2](https://huggingface.co/spaces/sanchit-gandhi/whisper-large-v2) as an ASR model, to transcribe recorded audio to text. It is called through a [gradio client](https://www.gradio.app/docs/client).
2. [Mistral-7b-instruct](https://huggingface.co/spaces/osanseviero/mistral-super-fast) as the chat model, which generates the text responses. It is called from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference).
3. [Coqui's XTTS](https://huggingface.co/spaces/coqui/xtts) as a TTS model, to voice the chatbot's answers. This model is hosted locally.

Note:
- By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml""")

demo.queue()
demo.launch(debug=True)