matthartman committed on
Commit
294ffc2
·
verified ·
1 Parent(s): ddf06c6

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ from fastrtc import AdditionalOutputs, ReplyOnPause, WebRTC, WebRTCData, get_cloudflare_turn_credentials_async
5
+ from threading import Thread
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ from transformers.generation.streamers import TextIteratorStreamer
8
+
9
# Checkpoint served by this demo: Gemma 3 27B, instruction-tuned.
MODEL_ID = "google/gemma-3-27b-it"

# Load the tokenizer and weights once at import time so every request
# reuses the same in-memory model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)
17
@spaces.GPU(time_limit=120)
def generate(data: WebRTCData, history, system_prompt="", max_new_tokens=512):
    """Stream an assistant reply for the text typed into the WebRTC textbox.

    Args:
        data: fastrtc payload; only ``data.textbox`` (the typed message) is read.
        history: chat transcript as a list of ``{"role", "content"}`` dicts
            (Gradio "messages" format). May be ``None`` on the first turn.
        system_prompt: optional system message prepended to the transcript.
        max_new_tokens: generation cap; Gradio sliders may deliver a float,
            so it is coerced to ``int`` before being passed to ``generate``.

    Yields:
        ``AdditionalOutputs(transcript)`` — first with the user turn appended,
        then repeatedly with the assistant turn growing token by token.
    """
    history = history or []  # fresh Chatbot can hand us None
    text = (data.textbox or "").strip()
    if not text:
        # Pause event with no typed text: leave the chat unchanged.
        yield AdditionalOutputs(history)
        return

    history.append({"role": "user", "content": text})
    # Echo the user turn immediately so the UI updates before generation.
    yield AdditionalOutputs(history)

    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages.extend(history)

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
    ).to(model.device)

    # Stream tokens out of model.generate via a background thread; the
    # streamer skips the prompt and special tokens so only reply text arrives.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        input_ids=inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=False,  # greedy decoding: deterministic replies
    )
    # daemon=True so a still-running generation never blocks process shutdown.
    Thread(target=model.generate, kwargs=gen_kwargs, daemon=True).start()

    reply = {"role": "assistant", "content": ""}
    for token in streamer:
        reply["content"] += token
        # history itself stays user-terminated; the partial reply is appended
        # only in the yielded copy so retries start from a clean transcript.
        yield AdditionalOutputs(history + [reply])
46
+
47
+
48
with gr.Blocks() as demo:
    # Transcript display in Gradio "messages" format (role/content dicts).
    chatbot = gr.Chatbot(type="messages")

    # Send-only WebRTC widget rendered as a textbox; TURN credentials are
    # fetched per-connection from Cloudflare.
    webrtc = WebRTC(
        modality="audio",
        mode="send",
        variant="textbox",
        rtc_configuration=get_cloudflare_turn_credentials_async,
    )

    # Generation settings, tucked into a collapsed panel.
    with gr.Accordion("Settings", open=False):
        system_prompt = gr.Textbox(
            "You are a helpful assistant.", label="System prompt"
        )
        max_new_tokens = gr.Slider(50, 1500, 700, label="Max new tokens")

    # Run generate() whenever the user pauses; it streams transcript updates.
    webrtc.stream(
        ReplyOnPause(generate),
        inputs=[webrtc, chatbot, system_prompt, max_new_tokens],
        outputs=[chatbot],
        concurrency_limit=100,
    )

    def _take_latest(old, new):
        # Replace the displayed transcript with the freshest one emitted
        # by the generator; the previous value is ignored.
        return new

    webrtc.on_additional_outputs(
        _take_latest, inputs=[chatbot], outputs=[chatbot]
    )

if __name__ == "__main__":
    demo.launch(ssr_mode=False)