awacke1 committed
Commit 558b853 · 1 Parent(s): 0325023

Create app.py

Files changed (1)
  1. app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
+ from __future__ import annotations
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ import nltk
+ import os
+ import uuid
+
+ os.environ["COQUI_TOS_AGREED"] = "1"
+ nltk.download('punkt')
+ from TTS.api import TTS
+ from huggingface_hub import HfApi
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+
+
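+ # COQUI_TOS_AGREED=1 above accepts Coqui's CPML terms non-interactively so the
+ # XTTS checkpoint can download at startup; the model is loaded locally on GPU.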
+ tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1", gpu=True)
+ title = "Voice Chat Mistral"
+ DESCRIPTION = title
+ css = """.toast-wrap { display: none !important } """
+ api = HfApi(token=HF_TOKEN)
+ repo_id = "ylacombe/voice-chat-with-lama"
+ system_message = "\nYou are a helpful assistant."
+ temperature = 0.9
+ top_p = 0.6
+ repetition_penalty = 1.2
+
+ from gradio_client import Client
+ from huggingface_hub import InferenceClient
+
+
+ whisper_client = Client("https://sanchit-gandhi-whisper-large-v2.hf.space/")
+ text_client = InferenceClient(
+     "mistralai/Mistral-7B-Instruct-v0.1"
+ )
+
+
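+ # Prompt layout expected by Mistral-7B-Instruct: each past turn is wrapped as
+ # "[INST] user [/INST] assistant</s>" and the new message is appended as a
+ # trailing open instruction for the model to complete, e.g.
+ #   <s>[INST] Hi [/INST] Hello!</s> [INST] How are you? [/INST]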
+ def format_prompt(message, history):
+     prompt = "<s>"
+     for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
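+ # Stream the reply token by token from the hosted endpoint, yielding the
+ # growing string so the chatbot message updates incrementally; temperature is
+ # floored at 0.01 to keep it strictly positive.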
+ def generate(
+     prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
+ ):
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
+     top_p = float(top_p)
+
+     generate_kwargs = dict(
+         temperature=temperature,
+         max_new_tokens=max_new_tokens,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+         seed=42,
+     )
+
+     formatted_prompt = format_prompt(prompt, history)
+
+     stream = text_client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+     output = ""
+
+     for response in stream:
+         output += response.token.text
+         yield output
+
+
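+ # Delegate ASR to the remote Whisper Space via its /predict endpoint, which
+ # takes an audio filepath plus the task name and returns the transcript text.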
+ def transcribe(wav_path):
+     return whisper_client.predict(
+         wav_path,        # str (filepath or URL to file) in 'inputs' Audio component
+         "transcribe",    # str in 'Task' Radio component
+         api_name="/predict"
+     )
+
+
+ # Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.
+
+
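+ # Both input paths append the user turn to the history as a (message, None)
+ # pair; bot() fills in the assistant half afterwards.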
+ def add_text(history, text):
+     history = [] if history is None else history
+     history = history + [(text, None)]
+     return history, gr.update(value="", interactive=False)
+
+
+ def add_file(history, file):
+     history = [] if history is None else history
+     text = transcribe(file)
+
+     history = history + [(text, None)]
+     return history
+
+
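+ # generate() yields the full response accumulated so far, so the last history
+ # entry is overwritten (not extended) on every streamed update.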
+ def bot(history, system_prompt=""):
+     history = [] if history is None else history
+
+     if system_prompt == "":
+         system_prompt = system_message
+
+     history[-1][1] = ""
+     for character in generate(history[-1][0], history[:-1]):
+         history[-1][1] = character
+         yield history
+
+
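+ # The reply is split into sentences with nltk and synthesized one at a time,
+ # streaming each chunk to the audio player as a (sample_rate, waveform) tuple.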
+ def generate_speech(history):
+     text_to_generate = history[-1][1]
+     text_to_generate = text_to_generate.replace("\n", " ").strip()
+     text_to_generate = nltk.sent_tokenize(text_to_generate)
+
+     filename = f"{uuid.uuid4()}.wav"
+     sampling_rate = tts.synthesizer.tts_config.audio["sample_rate"]
+     silence = [0] * int(0.25 * sampling_rate)
+
+     for sentence in text_to_generate:
+         try:
+             # Generate speech by cloning a voice, using default settings.
+             wav = tts.tts(text=sentence,
+                           speaker_wav="examples/female.wav",
+                           decoder_iterations=25,
+                           decoder_sampler="dpm++2m",
+                           speed=1.2,
+                           language="en")
+
+             yield (sampling_rate, np.array(wav))  # np.array(wav + silence))
+
+         except RuntimeError as e:
+             if "device-side assert" in str(e):
+                 # Nothing can be done about a CUDA device-side assert; the process needs a restart.
+                 print(f"Exit due to: unrecoverable exception caused by prompt: {sentence}", flush=True)
+                 gr.Warning("Unhandled exception encountered, please retry in a minute")
+                 print("CUDA device-side assert encountered, restart needed")
+
+                 # HF Space specific: this error is unrecoverable, so restart the Space.
+                 api.restart_space(repo_id=repo_id)
+             else:
+                 print("RuntimeError: non device-side assert error:", str(e))
+                 raise e
+
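+ # UI: a textbox and a microphone both feed the same chatbot, and every
+ # submission ends with the spoken version of the answer.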
+ with gr.Blocks(title=title) as demo:
+     gr.Markdown(DESCRIPTION)
+
+     chatbot = gr.Chatbot(
+         [],
+         elem_id="chatbot",
+         avatar_images=('examples/lama.jpeg', 'examples/lama2.jpeg'),
+         bubble_full_width=False,
+     )
+
+     with gr.Row():
+         txt = gr.Textbox(
+             scale=3,
+             show_label=False,
+             placeholder="Enter text and press enter, or speak to your microphone",
+             container=False,
+         )
+         txt_btn = gr.Button(value="Submit text", scale=1)
+         btn = gr.Audio(source="microphone", type="filepath", scale=4)
+
+     with gr.Row():
+         audio = gr.Audio(type="numpy", streaming=True, autoplay=True, label="Generated audio response", show_label=True)
+
+     clear_btn = gr.ClearButton([chatbot, audio])
+
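+     # Each input path (text button, Enter key, microphone) chains the same
+     # steps: append the user turn, stream the text reply, then speak it.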
+     txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
+         bot, chatbot, chatbot
+     ).then(generate_speech, chatbot, audio)
+
+     txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
+
+     txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
+         bot, chatbot, chatbot
+     ).then(generate_speech, chatbot, audio)
+
+     txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
+
+     file_msg = btn.stop_recording(add_file, [chatbot, btn], [chatbot], queue=False).then(
+         bot, chatbot, chatbot
+     ).then(generate_speech, chatbot, audio)
+
+
+     gr.Markdown("""
+ This Space demonstrates how to speak to a chatbot, based solely on open-source models.
+ It relies on 3 models:
+ 1. [Whisper-large-v2](https://huggingface.co/spaces/sanchit-gandhi/whisper-large-v2) as an ASR model, to transcribe recorded audio to text. It is called through a [gradio client](https://www.gradio.app/docs/client).
+ 2. [Mistral-7b-instruct](https://huggingface.co/spaces/osanseviero/mistral-super-fast) as the chat model. It is called from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference).
+ 3. [Coqui's XTTS](https://huggingface.co/spaces/coqui/xtts) as a TTS model, to voice the chatbot's answers. This model is hosted locally.
+ Note:
+ - By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml""")
+
+ demo.queue()
+ demo.launch(debug=True)