Commit: "Update app.py" — diff of app.py (changed)
@@ -8,11 +8,11 @@ import av
|
|
8 |
import gradio as gr
|
9 |
import spaces
|
10 |
import torch
|
11 |
-
from
|
12 |
from gradio.processing_utils import save_audio_to_cache
|
|
|
13 |
from transformers import AutoModelForImageTextToText, AutoProcessor
|
14 |
from transformers.generation.streamers import TextIteratorStreamer
|
15 |
-
from fastrtc import ReplyOnPause, WebRTCData, WebRTC, AdditionalOutputs, get_hf_turn_credentials
|
16 |
|
17 |
model_id = "google/gemma-3n-E4B-it"
|
18 |
|
@@ -202,12 +202,19 @@ def _generate(message: dict, history: list[dict], system_prompt: str = "", max_n
|
|
202 |
|
203 |
@spaces.GPU(time_limit=120)
|
204 |
def generate(data: WebRTCData, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512, image=None):
|
205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
new_message = {"role": "assistant", "content": ""}
|
207 |
for output in _generate(message, history, system_prompt, max_new_tokens):
|
208 |
new_message["content"] += output
|
209 |
yield AdditionalOutputs(history + [new_message])
|
210 |
-
|
211 |
|
212 |
|
213 |
with gr.Blocks() as demo:
|
@@ -217,12 +224,12 @@ with gr.Blocks() as demo:
|
|
217 |
mode="send",
|
218 |
variant="textbox",
|
219 |
rtc_configuration=get_hf_turn_credentials,
|
220 |
-
server_rtc_configuration=get_hf_turn_credentials(ttl=3_600 * 24 * 30)
|
221 |
)
|
222 |
with gr.Accordion(label="Additional Inputs"):
|
223 |
sp = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
|
224 |
slider = gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700)
|
225 |
-
image = gr.Image()
|
226 |
|
227 |
webrtc.stream(
|
228 |
ReplyOnPause(generate), # type: ignore
|
@@ -230,9 +237,7 @@ with gr.Blocks() as demo:
|
|
230 |
outputs=[chatbot],
|
231 |
concurrency_limit=100,
|
232 |
)
|
233 |
-
webrtc.on_additional_outputs(
|
234 |
-
lambda old, new: new, inputs=[chatbot], outputs=[chatbot], concurrency_limit=100
|
235 |
-
)
|
236 |
|
237 |
if __name__ == "__main__":
|
238 |
demo.launch()
|
|
|
8 |
import gradio as gr
|
9 |
import spaces
|
10 |
import torch
|
11 |
+
from fastrtc import AdditionalOutputs, ReplyOnPause, WebRTC, WebRTCData, get_hf_turn_credentials
|
12 |
from gradio.processing_utils import save_audio_to_cache
|
13 |
+
from gradio.utils import get_upload_folder
|
14 |
from transformers import AutoModelForImageTextToText, AutoProcessor
|
15 |
from transformers.generation.streamers import TextIteratorStreamer
|
|
|
16 |
|
17 |
model_id = "google/gemma-3n-E4B-it"
|
18 |
|
|
|
202 |
|
203 |
@spaces.GPU(time_limit=120)
def generate(data: WebRTCData, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512, image=None):
    """Stream an assistant reply for a WebRTC turn (text and/or audio, plus an optional image).

    Args:
        data: FastRTC payload; ``data.textbox`` is the typed text and ``data.audio``
            (when present) looks like a ``(sample_rate, samples)`` pair — it is
            unpacked as ``data.audio[0]`` = rate, ``data.audio[1]`` = array here;
            TODO confirm against the fastrtc docs.
        history: Chat history in Gradio "messages" format (list of role/content dicts).
        system_prompt: System prompt forwarded to ``_generate``.
        max_new_tokens: Generation budget forwarded to ``_generate``.
        image: Optional image filepath from the ``gr.Image(type="filepath")`` input.

    Yields:
        AdditionalOutputs wrapping the updated history as the assistant message streams in.
    """
    files = []
    # Persist the recorded audio to Gradio's upload cache so it can be attached
    # to the multimodal message as a file (skip empty recordings).
    if data.audio is not None and data.audio[1].size > 0:
        files.append(save_audio_to_cache(data.audio[1], data.audio[0], format="mp3", cache_dir=get_upload_folder()))
    # BUG FIX: original read `if image is None: files.append(image)`, which
    # appended None when no image was given and silently dropped a provided one.
    if image is not None:
        files.append(image)
    message = {
        "text": data.textbox,
        # BUG FIX: original hard-coded "files": [], discarding the files
        # collected above so audio/image never reached the model.
        "files": files,
    }
    new_message = {"role": "assistant", "content": ""}
    # Stream partial tokens; re-yield the full history each step so the
    # chatbot UI updates incrementally via AdditionalOutputs.
    for output in _generate(message, history, system_prompt, max_new_tokens):
        new_message["content"] += output
        yield AdditionalOutputs(history + [new_message])
|
|
218 |
|
219 |
|
220 |
with gr.Blocks() as demo:
|
|
|
224 |
mode="send",
|
225 |
variant="textbox",
|
226 |
rtc_configuration=get_hf_turn_credentials,
|
227 |
+
server_rtc_configuration=get_hf_turn_credentials(ttl=3_600 * 24 * 30),
|
228 |
)
|
229 |
with gr.Accordion(label="Additional Inputs"):
|
230 |
sp = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
|
231 |
slider = gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700)
|
232 |
+
image = gr.Image(type="filepath")
|
233 |
|
234 |
webrtc.stream(
|
235 |
ReplyOnPause(generate), # type: ignore
|
|
|
237 |
outputs=[chatbot],
|
238 |
concurrency_limit=100,
|
239 |
)
|
240 |
+
webrtc.on_additional_outputs(lambda old, new: new, inputs=[chatbot], outputs=[chatbot], concurrency_limit=100)
|
|
|
|
|
241 |
|
242 |
# Script entry point: launch the Gradio demo only when run directly,
# not when imported (e.g. by the Spaces runtime).
if __name__ == "__main__":
    demo.launch()
|