freddyaboulton HF Staff committed on
Commit
1255de4
·
verified ·
1 Parent(s): 1e0a351

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -8,11 +8,11 @@ import av
8
  import gradio as gr
9
  import spaces
10
  import torch
11
- from gradio.utils import get_upload_folder
12
  from gradio.processing_utils import save_audio_to_cache
 
13
  from transformers import AutoModelForImageTextToText, AutoProcessor
14
  from transformers.generation.streamers import TextIteratorStreamer
15
- from fastrtc import ReplyOnPause, WebRTCData, WebRTC, AdditionalOutputs, get_hf_turn_credentials
16
 
17
  model_id = "google/gemma-3n-E4B-it"
18
 
@@ -202,12 +202,19 @@ def _generate(message: dict, history: list[dict], system_prompt: str = "", max_n
202
 
203
  @spaces.GPU(time_limit=120)
204
  def generate(data: WebRTCData, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512, image=None):
205
- message = {"text": data.textbox, "files": [save_audio_to_cache(data.audio[1], data.audio[0], format="mp3", cache_dir=get_upload_folder())]}
 
 
 
 
 
 
 
 
206
  new_message = {"role": "assistant", "content": ""}
207
  for output in _generate(message, history, system_prompt, max_new_tokens):
208
  new_message["content"] += output
209
  yield AdditionalOutputs(history + [new_message])
210
-
211
 
212
 
213
  with gr.Blocks() as demo:
@@ -217,12 +224,12 @@ with gr.Blocks() as demo:
217
  mode="send",
218
  variant="textbox",
219
  rtc_configuration=get_hf_turn_credentials,
220
- server_rtc_configuration=get_hf_turn_credentials(ttl=3_600 * 24 * 30)
221
  )
222
  with gr.Accordion(label="Additional Inputs"):
223
  sp = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
224
  slider = gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700)
225
- image = gr.Image()
226
 
227
  webrtc.stream(
228
  ReplyOnPause(generate), # type: ignore
@@ -230,9 +237,7 @@ with gr.Blocks() as demo:
230
  outputs=[chatbot],
231
  concurrency_limit=100,
232
  )
233
- webrtc.on_additional_outputs(
234
- lambda old, new: new, inputs=[chatbot], outputs=[chatbot], concurrency_limit=100
235
- )
236
 
237
  if __name__ == "__main__":
238
  demo.launch()
 
8
  import gradio as gr
9
  import spaces
10
  import torch
11
+ from fastrtc import AdditionalOutputs, ReplyOnPause, WebRTC, WebRTCData, get_hf_turn_credentials
12
  from gradio.processing_utils import save_audio_to_cache
13
+ from gradio.utils import get_upload_folder
14
  from transformers import AutoModelForImageTextToText, AutoProcessor
15
  from transformers.generation.streamers import TextIteratorStreamer
 
16
 
17
  model_id = "google/gemma-3n-E4B-it"
18
 
 
202
 
203
@spaces.GPU(time_limit=120)
def generate(data: WebRTCData, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512, image=None):
    """Stream an assistant reply for one WebRTC chat turn.

    Collects the user's typed text, any recorded audio, and an optional image
    into a single multimodal message, then yields the growing chat history via
    AdditionalOutputs as the model streams tokens.

    Args:
        data: WebRTC payload; ``data.textbox`` is the typed text and
            ``data.audio`` is a ``(sample_rate, samples)`` pair (or ``None``).
        history: Prior conversation as a list of role/content dicts.
        system_prompt: System prompt prepended to the conversation.
        max_new_tokens: Token budget forwarded to the model.
        image: Optional image filepath from the ``gr.Image(type="filepath")`` input.

    Yields:
        AdditionalOutputs wrapping the updated history after each streamed chunk.
    """
    files = []
    # Only cache audio when a non-empty recording was actually captured.
    if data.audio is not None and data.audio[1].size > 0:
        files.append(save_audio_to_cache(data.audio[1], data.audio[0], format="mp3", cache_dir=get_upload_folder()))
    # BUG FIX: was `if image is None`, which appended None when no image was
    # provided and silently dropped real images.
    if image is not None:
        files.append(image)
    message = {
        "text": data.textbox,
        # BUG FIX: was the literal `[]`, which discarded the `files` list built
        # above and made the audio/image handling dead code.
        "files": files,
    }
    new_message = {"role": "assistant", "content": ""}
    for output in _generate(message, history, system_prompt, max_new_tokens):
        new_message["content"] += output
        yield AdditionalOutputs(history + [new_message])
 
218
 
219
 
220
  with gr.Blocks() as demo:
 
224
  mode="send",
225
  variant="textbox",
226
  rtc_configuration=get_hf_turn_credentials,
227
+ server_rtc_configuration=get_hf_turn_credentials(ttl=3_600 * 24 * 30),
228
  )
229
  with gr.Accordion(label="Additional Inputs"):
230
  sp = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
231
  slider = gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700)
232
+ image = gr.Image(type="filepath")
233
 
234
  webrtc.stream(
235
  ReplyOnPause(generate), # type: ignore
 
237
  outputs=[chatbot],
238
  concurrency_limit=100,
239
  )
240
+ webrtc.on_additional_outputs(lambda old, new: new, inputs=[chatbot], outputs=[chatbot], concurrency_limit=100)
 
 
241
 
242
  if __name__ == "__main__":
243
  demo.launch()