BryanBradfo committed on
Commit
ade8c92
·
1 Parent(s): 43339e8

Attempt to add icons

Browse files
Files changed (1) hide show
  1. app.py +271 -173
app.py CHANGED
@@ -1,182 +1,280 @@
1
  import os
2
  import queue
3
- import gradio as gr
4
  from threading import Thread
5
- from typing import List, Tuple
6
-
7
############################################################
# EXAMPLE: Dummy function to simulate streaming generation #
############################################################
def stream_generate(message: str, history: List[Tuple[str, str]]):
    """Placeholder streaming generator for demo purposes.

    Replace with real model logic; a production version would yield
    tokens from the model's stream (or drain a queue, catching
    ``queue.Empty`` at end of stream).

    Yields the reply accumulated so far, growing one character per step.
    """
    response_text = (
        "This is an example answer to your query.\n"
        "Feel free to replace me with real model output!"
    )
    # Emit ever-growing prefixes of the canned reply to simulate streaming.
    for end in range(1, len(response_text) + 1):
        yield response_text[:end]
28
-
29
-
30
def respond(message: str, chat_history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], str]:
    """Drive one chat turn for the UI.

    Appends the user's message, streams the bot reply from
    ``stream_generate``, and yields ``(updated_history, "")`` pairs —
    the empty string clears the input textbox on each refresh.
    """
    # Record the user's message with an empty bot slot to be filled in.
    chat_history.append((message, ""))

    streamed: List[Tuple[str, str]] = []
    try:
        # Re-emit the whole history with the last pair swapped for the
        # partial text, so the chatbot updates in real time.
        for partial_text in stream_generate(message, chat_history):
            streamed = chat_history[:-1] + [(message, partial_text)]
            yield streamed, ""
    except queue.Empty:
        # End of stream; swallow the timeout instead of showing a traceback.
        pass

    # Finalize the last user -> bot pair with whatever text arrived.
    if streamed:
        final_answer = streamed[-1][1]
    else:
        final_answer = "No response."
    chat_history[-1] = (message, final_answer)
    yield chat_history, ""
53
-
54
-
55
############################################################
# GRADIO BLOCKS UI                                         #
############################################################
def launch_app():
    """Build and launch the Gradio Blocks chat UI.

    Fixes vs. the previous revision:
    * ``gr.Row`` has no ``classes`` keyword — use ``elem_classes``.
    * The suggestion buttons passed ``fn=set_example_text`` with
      ``inputs=[]`` (the function requires one argument → TypeError at
      click time) and relied on the private ``_js`` hook.  Each button
      now binds its example text as a lambda default argument and writes
      it into the textbox via ``outputs``.
    """
    # A custom CSS snippet to style backgrounds, suggestions, bubbles, etc.
    custom_css = """
    :root {
      --gradient-start: #66AEEF; /* lighter top */
      --gradient-end: #F0F8FF;   /* very light at bottom */
    }
    html, body {
      margin: 0;
      padding: 0;
      background: linear-gradient(to bottom, var(--gradient-start), var(--gradient-end));
      font-family: "Helvetica", sans-serif;
      color: #333;
    }
    h1 {
      text-align: center;
      color: #fff;
      margin-top: 1.2em;
    }
    /* Chatbot styling */
    .gradio-container {
      max-width: 800px;
      margin: 0 auto;
      padding-bottom: 2rem;
    }
    .chatbot {
      background-color: #F8FDFF !important;
    }
    .chatbot .message {
      border-radius: 8px;
      margin: 6px;
      padding: 10px;
      line-height: 1.4;
      position: relative;
    }
    .chatbot .user .chat-avatar {
      background: url('user.png') center center no-repeat;
      background-size: cover;
    }
    .chatbot .bot .chat-avatar {
      background: url('gemma.png') center center no-repeat;
      background-size: cover;
    }
    /* Example suggestions row */
    .examples-row {
      display: flex;
      gap: 8px;
      flex-wrap: wrap;
      justify-content: center;
      margin-bottom: 20px;
    }
    .examples-row button {
      background-color: #EAF4FF;
      border: 1px solid #66AEEF;
      border-radius: 8px;
      padding: 8px 14px;
      color: #333;
      cursor: pointer;
    }
    .examples-row button:hover {
      background-color: #D8ECFE;
    }
    """

    with gr.Blocks(css=custom_css) as demo:
        gr.Markdown("<h1>Hi, I'm Gemma-2 (2B) </h1>")

        with gr.Row():
            # We store conversation in a state variable.
            chat_state = gr.State([])

        # A row of clickable suggestions.  Buttons are created here but
        # wired up after the textbox exists (it is their output component).
        examples = [
            "Hello there! How are you doing?",
            "Can you explain briefly what Python is?",
            "Explain the plot of Cinderella in a sentence.",
            "How many hours does it take a man to eat a Helicopter?",
            "Write a 100-word article on 'Benefits of Open-Source in AI research'",
        ]
        suggestion_buttons = []
        # BUG FIX: `classes` is not a gr.Row keyword; `elem_classes` is.
        with gr.Row(elem_id="examples-row", elem_classes="examples-row"):
            for ex in examples:
                suggestion_buttons.append(gr.Button(ex))

        # Our custom chatbot interface.
        chatbot = gr.Chatbot(
            label="Gemma Chat",
            elem_id="chat_window",
            height=400,
            avatar_images=("user.png", "gemma.png"),
            # optionally show_copy_button=True,
        )

        # A row with user input + submit.
        with gr.Row():
            user_input = gr.Textbox(
                label="Your message:",
                placeholder="Type something...",
                lines=2,
                elem_id="user_input",
            )
            submit_btn = gr.Button("Send")

        # Each suggestion fills the textbox with its example text.
        # Binding ``ex`` as a default argument avoids the late-binding
        # closure pitfall (all lambdas seeing only the last value).
        for ex, btn in zip(examples, suggestion_buttons):
            btn.click(fn=lambda ex=ex: ex, inputs=[], outputs=[user_input])

        # Link the `respond` generator to handle the conversation.
        submit_btn.click(
            fn=respond,
            inputs=[user_input, chat_state],
            outputs=[chatbot, user_input],
            queue=True,
        )

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
 
 
 
 
 
 
 
 
 
 
 
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
if __name__ == "__main__":
    # Script entry point: build the UI and start the server.
    launch_app()
 
 
1
  import os
2
  import queue
3
+ from collections.abc import Iterator
4
  from threading import Thread
5
+
6
+ import gradio as gr
7
+ import spaces
8
+ import torch
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
+
11
# Header markdown rendered above the chat interface.
DESCRIPTION = """\
<h1 style="text-align: center;">Hi, I'm Gemma 2 (2B) 👋</h1>

This is a demo of <strong>google/gemma-2-2b-it</strong> fine-tuned for instruction following. For more details, please check <a href="https://huggingface.co/blog/gemma2" target="_blank">the post</a>.

👉 Looking for a larger and more powerful version? Try the 27B version in <a href="https://huggingface.co/chat/models/google/gemma-2-27b-it" target="_blank">HuggingChat</a> and the 9B version in <a href="https://huggingface.co/spaces/huggingface-projects/gemma-2-9b-it" target="_blank">this Space</a>.
"""

# Generation limits; the prompt budget is overridable via the environment.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Prefer the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Load the instruction-tuned Gemma 2 (2B) checkpoint in bfloat16.
model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.config.sliding_window = 4096  # cap the attention window
model.eval()  # inference mode (disables dropout etc.)
34
+
35
+
36
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Generate a reply with the Gemma model, streaming partial text to the UI.

    Args:
        message: Latest user message.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts
            (Gradio ``type="messages"`` format).
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling controls forwarded to ``model.generate``.

    Yields:
        The full response text accumulated so far, once per new chunk.
    """
    conversation = chat_history.copy()
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens so the prompt fits the budget.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Run generation on a worker thread and stream tokens back as they arrive.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    # FIX: build kwargs as a plain keyword dict (was dict({"input_ids": ...}, ...)).
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    outputs: list[str] = []
    try:
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)
    except queue.Empty:
        # The streamer waited 20 s without a new token; end the stream quietly.
        # NOTE(review): the generation thread may still be running at this point.
        return
    # FIX: on normal completion the generation thread has already signalled
    # end-of-stream, so this join returns promptly and avoids leaking the
    # worker thread (the previous revision never joined it).
    thread.join()
81
+
82
+
83
# Gemini-themed CSS for the ChatInterface: page gradient, chat bubbles with
# user/bot icons, and the built-in examples restyled as clickable pills.
gemini_css = """
:root {
  --gradient-start: #66AEEF; /* lighter top */
  --gradient-end: #F0F8FF;   /* very light at bottom */
}

/* Overall page & container background gradient */
html, body, .gradio-container {
  margin: 0;
  padding: 0;
  background: linear-gradient(to bottom, var(--gradient-start), var(--gradient-end));
  font-family: "Helvetica", sans-serif;
  color: #333; /* dark gray for better contrast */
}

/* Make anchor (link) text a clearly visible dark blue */
a, a:visited {
  color: #02497A !important;
  text-decoration: underline;
}

/* Center the top headings in the description */
.gradio-container h1 {
  margin-top: 0.8em;
  margin-bottom: 0.5em;
  text-align: center;
  color: #fff; /* White text on top gradient for pop */
}

/* Chat container background: a very light blue so it's distinct from the outer gradient */
.chat-interface, .chat-interface .wrap {
  background-color: #F8FDFF !important;
}

/* Remove harsh frames around chat messages */
.chat-message {
  border: none !important;
  position: relative;
}

/* Icons for user and bot messages */
.chat-message.user::before {
  content: '';
  display: inline-block;
  background: url('user.png') center center no-repeat;
  background-size: cover;
  width: 24px;
  height: 24px;
  margin-right: 8px;
  vertical-align: middle;
}

.chat-message.bot::before {
  content: '';
  display: inline-block;
  background: url('gemma.png') center center no-repeat;
  background-size: cover;
  width: 24px;
  height: 24px;
  margin-right: 8px;
  vertical-align: middle;
}

/* User bubble: a deeper blue with white text */
.chat-message.user {
  background-color: #0284C7 !important;
  color: #FFFFFF !important;
  border-radius: 8px;
  padding: 8px 12px;
  margin: 6px 0;
}

/* Bot bubble: very light blue with darker text */
.chat-message.bot {
  background-color: #EFF8FF !important;
  color: #333 !important;
  border-radius: 8px;
  padding: 8px 12px;
  margin: 6px 0;
}

/* Chat input area */
.chat-input textarea {
  background-color: #FFFFFF;
  color: #333;
  border: 1px solid #66AEEF;
  border-radius: 6px;
  padding: 8px;
}

/* Sliders & other controls */
form.sliders input[type="range"] {
  accent-color: #66AEEF;
}
form.sliders label {
  color: #333;
}

.gradio-button, .chat-send-btn {
  background-color: #0284C7 !important;
  color: #FFFFFF !important;
  border-radius: 5px;
  border: none;
  cursor: pointer;
}
.gradio-button:hover, .chat-send-btn:hover {
  background-color: #026FA6 !important;
}

/* Style the example "pill" buttons (the built-in ChatInterface examples) */
.gr-examples {
  display: flex !important;
  flex-wrap: wrap;
  gap: 16px;
  justify-content: center;
  margin-bottom: 1em !important;
}
.gr-examples button.example {
  background-color: #EFF8FF !important;
  border: 1px solid #66AEEF !important;
  border-radius: 8px !important;
  color: #333 !important;
  padding: 10px 16px !important;
  cursor: pointer !important;
  transition: background-color 0.2s !important;
}
.gr-examples button.example:hover {
  background-color: #E0F2FF !important;
}

/* Additional spacing / small tweaks */
#duplicate-button {
  margin: auto;
  background: #1565c0;
  border-radius: 100vh;
  color: #fff;
}
"""
223
 
224
# Assemble the chat UI around the streaming `generate` function.
# `type="messages"` makes Gradio pass history as role/content dicts,
# matching the signature of `generate`.
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    description=DESCRIPTION,
    css=gemini_css,
    fill_height=True,
    stop_btn=None,
    cache_examples=False,
    # Sampling controls exposed under "Additional inputs"; the order must
    # match the extra parameters of `generate`.
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    # Clickable starter prompts rendered below the chat window.
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)
277
 
278
if __name__ == "__main__":
    # queue() enables streaming/concurrency; cap the request backlog at 20.
    demo.queue(max_size=20).launch()