Luongsosad committed
Commit 191e3d6 · 1 Parent(s): 807118a
Files changed (2)
  1. app.py +567 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,567 @@
+ import gradio as gr
+ from huggingface_hub import InferenceClient
+ import os
+ import json
+ import base64
+ from PIL import Image
+ import io
+
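+ # Default Hugging Face token, read from the environment; used whenever the user does not supply their own key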
+ ACCESS_TOKEN = os.getenv("HF_TOKEN")
+ print("Access token loaded.")
+
+ # Function to encode image to base64
+ def encode_image(image_path):
+     if not image_path:
+         print("No image path provided")
+         return None
+
+     try:
+         print(f"Encoding image from path: {image_path}")
+
+         # If it's already a PIL Image
+         if isinstance(image_path, Image.Image):
+             image = image_path
+         else:
+             # Try to open the image file
+             image = Image.open(image_path)
+
+         # Convert to RGB if image has an alpha channel (RGBA)
+         if image.mode == 'RGBA':
+             image = image.convert('RGB')
+
+         # Encode to base64
+         buffered = io.BytesIO()
+         image.save(buffered, format="JPEG")
+         img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+         print("Image encoded successfully")
+         return img_str
+     except Exception as e:
+         print(f"Error encoding image: {e}")
+         return None
+
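+ # Streaming chat handler: builds an OpenAI-style message list (text and/or base64-encoded images),
+ # sends it to the selected provider via InferenceClient.chat_completion, and yields the reply incrementally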
+ def respond(
+     message,
+     image_files, # Changed parameter name and structure
+     history: list[tuple[str, str]],
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+     frequency_penalty,
+     seed,
+     provider,
+     custom_api_key,
+     custom_model,
+     model_search_term,
+     selected_model
+ ):
+     print(f"Received message: {message}")
+     print(f"Received {len(image_files) if image_files else 0} images")
+     print(f"History: {history}")
+     print(f"System message: {system_message}")
+     print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
+     print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
+     print(f"Selected provider: {provider}")
+     print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
+     print(f"Selected model (custom_model): {custom_model}")
+     print(f"Model search term: {model_search_term}")
+     print(f"Selected model from radio: {selected_model}")
+
+     # Determine which token to use
+     token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN
+
+     if custom_api_key.strip() != "":
+         print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
+     else:
+         print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")
+
+     # Initialize the Inference Client with the provider and appropriate token
+     client = InferenceClient(token=token_to_use, provider=provider)
+     print(f"Hugging Face Inference Client initialized with {provider} provider.")
+
+     # Convert seed to None if -1 (meaning random)
+     if seed == -1:
+         seed = None
+
+     # Create multimodal content if images are present
+     if image_files and len(image_files) > 0:
+         # Process the user message to include images
+         user_content = []
+
+         # Add text part if there is any
+         if message and message.strip():
+             user_content.append({
+                 "type": "text",
+                 "text": message
+             })
+
+         # Add image parts
+         for img in image_files:
+             if img is not None:
+                 # Get raw image data from path
+                 try:
+                     encoded_image = encode_image(img)
+                     if encoded_image:
+                         user_content.append({
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": f"data:image/jpeg;base64,{encoded_image}"
+                             }
+                         })
+                 except Exception as e:
+                     print(f"Error encoding image: {e}")
+     else:
+         # Text-only message
+         user_content = message
+
+     # Prepare messages in the format expected by the API
+     messages = [{"role": "system", "content": system_message}]
+     print("Initial messages array constructed.")
+
+     # Add conversation history to the context
+     for val in history:
+         user_part = val[0]
+         assistant_part = val[1]
+         if user_part:
+             # Handle both text-only and multimodal messages in history
+             if isinstance(user_part, tuple) and len(user_part) == 2:
+                 # This is a multimodal message with text and images
+                 history_content = []
+                 if user_part[0]: # Text
+                     history_content.append({
+                         "type": "text",
+                         "text": user_part[0]
+                     })
+
+                 for img in user_part[1]: # Images
+                     if img:
+                         try:
+                             encoded_img = encode_image(img)
+                             if encoded_img:
+                                 history_content.append({
+                                     "type": "image_url",
+                                     "image_url": {
+                                         "url": f"data:image/jpeg;base64,{encoded_img}"
+                                     }
+                                 })
+                         except Exception as e:
+                             print(f"Error encoding history image: {e}")
+
+                 messages.append({"role": "user", "content": history_content})
+             else:
+                 # Regular text message
+                 messages.append({"role": "user", "content": user_part})
+                 print(f"Added user message to context (type: {type(user_part)})")
+
+         if assistant_part:
+             messages.append({"role": "assistant", "content": assistant_part})
+             print(f"Added assistant message to context: {assistant_part}")
+
+     # Append the latest user message
+     messages.append({"role": "user", "content": user_content})
+     print(f"Latest user message appended (content type: {type(user_content)})")
+
+     # Determine which model to use, prioritizing custom_model if provided
+     model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
+     print(f"Model selected for inference: {model_to_use}")
+
+     # Start with an empty string to build the response as tokens stream in
+     response = ""
+     print(f"Sending request to {provider} provider.")
+
+     # Prepare parameters for the chat completion request
+     parameters = {
+         "max_tokens": max_tokens,
+         "temperature": temperature,
+         "top_p": top_p,
+         "frequency_penalty": frequency_penalty,
+     }
+
+     if seed is not None:
+         parameters["seed"] = seed
+
+     # Use the InferenceClient for making the request
+     try:
+         # Create a generator for the streaming response
+         stream = client.chat_completion(
+             model=model_to_use,
+             messages=messages,
+             stream=True,
+             **parameters
+         )
+
+         print("Received tokens: ", end="", flush=True)
+
+         # Process the streaming response
+         for chunk in stream:
+             if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
+                 # Extract the content from the response
+                 if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
+                     token_text = chunk.choices[0].delta.content
+                     if token_text:
+                         print(token_text, end="", flush=True)
+                         response += token_text
+                         yield response
+
+         print()
+     except Exception as e:
+         print(f"Error during inference: {e}")
+         response += f"\nError: {str(e)}"
+         yield response
+
+     print("Completed response generation.")
+
+ # Function to validate provider selection based on BYOK
+ def validate_provider(api_key, provider):
+     if not api_key.strip() and provider != "hf-inference":
+         return gr.update(value="hf-inference")
+     return gr.update(value=provider)
+
+ # GRADIO UI
+ with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
+     # Create the chatbot component
+     chatbot = gr.Chatbot(
+         height=600,
+         show_copy_button=True,
+         placeholder="Select a model and begin chatting. Now supports multiple inference providers and multimodal inputs",
+         layout="panel"
+     )
+     print("Chatbot interface created.")
+
+     # Multimodal textbox for messages (combines text and file uploads)
+     msg = gr.MultimodalTextbox(
+         placeholder="Type a message or upload images...",
+         show_label=False,
+         container=False,
+         scale=12,
+         file_types=["image"],
+         file_count="multiple",
+         sources=["upload"]
+     )
+
+     # Note: We're removing the separate submit button since MultimodalTextbox has its own
+
+     # Create accordion for settings
+     with gr.Accordion("Settings", open=False):
+         # System message
+         system_message_box = gr.Textbox(
+             value="You are a helpful AI assistant that can understand images and text.",
+             placeholder="You are a helpful assistant.",
+             label="System Prompt"
+         )
+
+         # Generation parameters
+         with gr.Row():
+             with gr.Column():
+                 max_tokens_slider = gr.Slider(
+                     minimum=1,
+                     maximum=4096,
+                     value=512,
+                     step=1,
+                     label="Max tokens"
+                 )
+
+                 temperature_slider = gr.Slider(
+                     minimum=0.1,
+                     maximum=4.0,
+                     value=0.7,
+                     step=0.1,
+                     label="Temperature"
+                 )
+
+                 top_p_slider = gr.Slider(
+                     minimum=0.1,
+                     maximum=1.0,
+                     value=0.95,
+                     step=0.05,
+                     label="Top-P"
+                 )
+
+             with gr.Column():
+                 frequency_penalty_slider = gr.Slider(
+                     minimum=-2.0,
+                     maximum=2.0,
+                     value=0.0,
+                     step=0.1,
+                     label="Frequency Penalty"
+                 )
+
+                 seed_slider = gr.Slider(
+                     minimum=-1,
+                     maximum=65535,
+                     value=-1,
+                     step=1,
+                     label="Seed (-1 for random)"
+                 )
+
+         # Provider selection
+         providers_list = [
+             "hf-inference", # Default Hugging Face Inference
+             "cerebras", # Cerebras provider
+             "together", # Together AI
+             "sambanova", # SambaNova
+             "novita", # Novita AI
+             "cohere", # Cohere
+             "fireworks-ai", # Fireworks AI
+             "hyperbolic", # Hyperbolic
+             "nebius", # Nebius
+         ]
+
+         provider_radio = gr.Radio(
+             choices=providers_list,
+             value="hf-inference",
+             label="Inference Provider",
+         )
+
+         # New BYOK textbox
+         byok_textbox = gr.Textbox(
+             value="",
+             label="BYOK (Bring Your Own Key)",
+             info="Enter a custom Hugging Face API key here. When empty, only 'hf-inference' provider can be used.",
+             placeholder="Enter your Hugging Face API token",
+             type="password" # Hide the API key for security
+         )
+
+         # Custom model box
+         custom_model_box = gr.Textbox(
+             value="",
+             label="Custom Model",
+             info="(Optional) Provide a custom Hugging Face model path. Overrides any selected featured model.",
+             placeholder="meta-llama/Llama-3.3-70B-Instruct"
+         )
+
+         # Model search
+         model_search_box = gr.Textbox(
+             label="Filter Models",
+             placeholder="Search for a featured model...",
+             lines=1
+         )
+
+         # Featured models list
+         # Updated to include multimodal models
+         models_list = [
+             "meta-llama/Llama-3.2-11B-Vision-Instruct",
+             "meta-llama/Llama-3.3-70B-Instruct",
+             "meta-llama/Llama-3.1-70B-Instruct",
+             "meta-llama/Llama-3.0-70B-Instruct",
+             "meta-llama/Llama-3.2-3B-Instruct",
+             "meta-llama/Llama-3.2-1B-Instruct",
+             "meta-llama/Llama-3.1-8B-Instruct",
+             "NousResearch/Hermes-3-Llama-3.1-8B",
+             "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
+             "mistralai/Mistral-Nemo-Instruct-2407",
+             "mistralai/Mixtral-8x7B-Instruct-v0.1",
+             "mistralai/Mistral-7B-Instruct-v0.3",
+             "mistralai/Mistral-7B-Instruct-v0.2",
+             "Qwen/Qwen3-235B-A22B",
+             "Qwen/Qwen3-32B",
+             "Qwen/Qwen2.5-72B-Instruct",
+             "Qwen/Qwen2.5-3B-Instruct",
+             "Qwen/Qwen2.5-0.5B-Instruct",
+             "Qwen/QwQ-32B",
+             "Qwen/Qwen2.5-Coder-32B-Instruct",
+             "microsoft/Phi-3.5-mini-instruct",
+             "microsoft/Phi-3-mini-128k-instruct",
+             "microsoft/Phi-3-mini-4k-instruct",
+         ]
+
+         featured_model_radio = gr.Radio(
+             label="Select a model below",
+             choices=models_list,
+             value="meta-llama/Llama-3.2-11B-Vision-Instruct", # Default to a multimodal model
+             interactive=True
+         )
+
+         gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")
+
+     # Chat history state
+     chat_history = gr.State([])
+
+     # Function to filter models
+     def filter_models(search_term):
+         print(f"Filtering models with search term: {search_term}")
+         filtered = [m for m in models_list if search_term.lower() in m.lower()]
+         print(f"Filtered models: {filtered}")
+         return gr.update(choices=filtered)
+
+     # Function to set custom model from radio
+     def set_custom_model_from_radio(selected):
+         print(f"Featured model selected: {selected}")
+         return selected
+
+     # Function for the chat interface
+     def user(user_message, history):
+         # Debug logging for troubleshooting
+         print(f"User message received: {user_message}")
+
+         # Skip if message is empty (no text and no files)
+         if not user_message or (not user_message.get("text") and not user_message.get("files")):
+             print("Empty message, skipping")
+             return history
+
+         # Prepare multimodal message format
+         text_content = user_message.get("text", "").strip()
+         files = user_message.get("files", [])
+
+         print(f"Text content: {text_content}")
+         print(f"Files: {files}")
+
+         # If both text and files are empty, skip
+         if not text_content and not files:
+             print("No content to display")
+             return history
+
+         # Add message with images to history
+         if files and len(files) > 0:
+             # Add text message first if it exists
+             if text_content:
+                 # Add a separate text message
+                 print(f"Adding text message: {text_content}")
+                 history.append([text_content, None])
+
+             # Then add each image file separately
+             for file_path in files:
+                 if file_path and isinstance(file_path, str):
+                     print(f"Adding image: {file_path}")
+                     # Add image as a separate message with no text
+                     history.append([f"![Image]({file_path})", None])
+
+             return history
+         else:
+             # For text-only messages
+             print(f"Adding text-only message: {text_content}")
+             history.append([text_content, None])
+             return history
+
+     # Define bot response function
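+     # Streams the assistant's reply into the last history entry; image messages are recognized by the ![Image](path) marker that user() adds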
+     def bot(history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
+         # Check if history is valid
+         if not history or len(history) == 0:
+             print("No history to process")
+             return history
+
+         # Get the most recent message and detect if it's an image
+         user_message = history[-1][0]
+         print(f"Processing user message: {user_message}")
+
+         is_image = False
+         image_path = None
+         text_content = user_message
+
+         # Check if this is an image message (marked with ![Image])
+         if isinstance(user_message, str) and user_message.startswith("![Image]("):
+             is_image = True
+             # Extract image path from markdown format ![Image](path)
+             image_path = user_message.replace("![Image](", "").replace(")", "")
+             print(f"Image detected: {image_path}")
+             text_content = "" # No text for image-only messages
+
+         # Look back for text context if this is an image
+         text_context = ""
+         if is_image and len(history) > 1:
+             # Use the previous message as context if it's text
+             prev_message = history[-2][0]
+             if isinstance(prev_message, str) and not prev_message.startswith("![Image]("):
+                 text_context = prev_message
+                 print(f"Using text context from previous message: {text_context}")
+
+         # Process message through respond function
+         history[-1][1] = ""
+
+         # Use either the image or text for the API
+         if is_image:
+             # For image messages
+             for response in respond(
+                 text_context, # Text context from previous message if any
+                 [image_path], # Current image
+                 history[:-1], # Previous history
+                 system_msg,
+                 max_tokens,
+                 temperature,
+                 top_p,
+                 freq_penalty,
+                 seed,
+                 provider,
+                 api_key,
+                 custom_model,
+                 search_term,
+                 selected_model
+             ):
+                 history[-1][1] = response
+                 yield history
+         else:
+             # For text-only messages
+             for response in respond(
+                 text_content, # Text message
+                 None, # No image
+                 history[:-1], # Previous history
+                 system_msg,
+                 max_tokens,
+                 temperature,
+                 top_p,
+                 freq_penalty,
+                 seed,
+                 provider,
+                 api_key,
+                 custom_model,
+                 search_term,
+                 selected_model
+             ):
+                 history[-1][1] = response
+                 yield history
+
+     # Event handlers - only using the MultimodalTextbox's built-in submit functionality
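+     # Chain: user() appends the new message to the history, bot() streams the reply, then the textbox is cleared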
+     msg.submit(
+         user,
+         [msg, chatbot],
+         [chatbot],
+         queue=False
+     ).then(
+         bot,
+         [chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
+          frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
+          model_search_box, featured_model_radio],
+         [chatbot]
+     ).then(
+         lambda: {"text": "", "files": []}, # Clear inputs after submission
+         None,
+         [msg]
+     )
+
+     # Connect the model filter to update the radio choices
+     model_search_box.change(
+         fn=filter_models,
+         inputs=model_search_box,
+         outputs=featured_model_radio
+     )
+     print("Model search box change event linked.")
+
+     # Connect the featured model radio to update the custom model box
+     featured_model_radio.change(
+         fn=set_custom_model_from_radio,
+         inputs=featured_model_radio,
+         outputs=custom_model_box
+     )
+     print("Featured model radio button change event linked.")
+
+     # Connect the BYOK textbox to validate provider selection
+     byok_textbox.change(
+         fn=validate_provider,
+         inputs=[byok_textbox, provider_radio],
+         outputs=provider_radio
+     )
+     print("BYOK textbox change event linked.")
+
+     # Also validate provider when the radio changes to ensure consistency
+     provider_radio.change(
+         fn=validate_provider,
+         inputs=[byok_textbox, provider_radio],
+         outputs=provider_radio
+     )
+     print("Provider radio button change event linked.")
+
+ print("Gradio interface initialized.")
+
+ if __name__ == "__main__":
+     print("Launching the demo application.")
+     demo.launch(show_api=True)
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ gradio
+ openai