Spaces:
Running
Running
| import os | |
| import base64 | |
| import requests | |
| import gradio as gr | |
| from huggingface_hub import InferenceClient | |
| from dataclasses import dataclass | |
| import pytesseract | |
| from PIL import Image | |
| import easyocr | |
@dataclass
class ChatMessage:
    """Custom ChatMessage class since huggingface_hub doesn't provide one.

    A minimal role/content pair. ``to_dict`` produces the plain-dict message
    shape that ``XylariaChat.get_response`` passes to
    ``InferenceClient.chat_completion``.

    The ``@dataclass`` decorator generates the ``__init__`` that callers rely
    on (``ChatMessage(role=..., content=...)``); without it construction
    raises ``TypeError``.
    """

    role: str  # "system", "user" or "assistant" as used in this file
    content: str  # message text

    def to_dict(self):
        """Converts ChatMessage to a dictionary for JSON serialization."""
        return {"role": self.role, "content": self.content}
class XylariaChat:
    """Gradio chat front-end for Qwen/QwQ-32B-Preview via the HF Inference API.

    Optional extras prepended to the user's message:
      * an image caption from microsoft/git-large-coco (``caption_image``), and
      * EasyOCR text from a "math" image (``perform_math_ocr``).
    A small key/value ``persistent_memory`` is injected as a system message.

    NOTE(review): ``conversation_history`` and ``persistent_memory`` live on the
    single instance built in ``main()``, and every Gradio handler closes over
    it, so all sessions served by this process appear to share that state.
    """

    # Chat model id, used by __init__, reset_conversation and get_response.
    MODEL_ID = "Qwen/QwQ-32B-Preview"

    # Seconds to wait for the image-captioning endpoint before giving up.
    IMAGE_API_TIMEOUT = 60

    def __init__(self):
        """Load HF_TOKEN, create the API clients and the OCR reader.

        Raises:
            ValueError: if the HF_TOKEN environment variable is unset or empty.
        """
        # Securely load HuggingFace token
        self.hf_token = os.getenv("HF_TOKEN")
        if not self.hf_token:
            raise ValueError("HuggingFace token not found in environment variables")

        # Initialize the inference client with the Qwen model
        self.client = self._make_client()

        # Image captioning API setup
        self.image_api_url = "https://api-inference.huggingface.co/models/microsoft/git-large-coco"
        self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}

        # Initialize conversation history and persistent memory
        self.conversation_history = []
        self.persistent_memory = {}

        # System prompt with more detailed instructions
        self.system_prompt = """You are a helpful and harmless assistant. You are Xylaria developed by Sk Md Saad Amin. You should think step-by-step. You should respond to image questions"""

        # English-only EasyOCR reader, used by perform_math_ocr.
        self.reader = easyocr.Reader(['en'])

    def _make_client(self):
        """Return a chat InferenceClient for MODEL_ID authenticated with hf_token."""
        return InferenceClient(
            model=self.MODEL_ID,
            api_key=self.hf_token,
        )

    def store_information(self, key, value):
        """Store important information in persistent memory.

        Returns:
            str: A confirmation of the form ``"Stored: <key> = <value>"``.
        """
        self.persistent_memory[key] = value
        return f"Stored: {key} = {value}"

    def retrieve_information(self, key):
        """Retrieve information from persistent memory.

        Returns:
            The stored value, or a "No information found" message if absent.
        """
        return self.persistent_memory.get(key, "No information found for this key.")

    def reset_conversation(self):
        """Clear conversation history and persistent memory, then rebuild the client.

        ``get_response`` sends the full message list on every request, so
        emptying the local lists is what resets what the model sees.

        Returns:
            None, which Gradio uses to empty the Chatbot it is wired to.
        """
        # Clear local memory
        self.conversation_history = []
        self.persistent_memory.clear()

        # Reinitialize the client (not strictly necessary for the API, but can help with local state)
        try:
            self.client = self._make_client()
        except Exception as e:
            print(f"Error resetting API client: {e}")

        return None  # To clear the chatbot interface

    def caption_image(self, image):
        """
        Caption an uploaded image using Hugging Face API

        Args:
            image: Path to an image file, a base64 string (optionally a
                ``data:image...`` URI), or an object whose ``read()`` returns bytes.

        Returns:
            str: Image caption or error message. Errors are returned, never raised.
        """
        try:
            # If image is a file path, read and encode
            if isinstance(image, str) and os.path.isfile(image):
                with open(image, "rb") as f:
                    data = f.read()
            # If image is already base64 encoded
            elif isinstance(image, str):
                # Remove data URI prefix if present
                if image.startswith('data:image'):
                    image = image.split(',')[1]
                data = base64.b64decode(image)
            # If image is a file-like object (unlikely with Gradio, but good to have)
            else:
                data = image.read()

            # Raw image bytes are sent as the request body. The timeout keeps a
            # stalled endpoint from blocking the chat handler indefinitely; a
            # timeout exception is reported by the except clause below.
            response = requests.post(
                self.image_api_url,
                headers=self.image_api_headers,
                data=data,
                timeout=self.IMAGE_API_TIMEOUT,
            )

            if response.status_code == 200:
                return response.json()[0].get('generated_text', 'No caption generated')
            return f"Error captioning image: {response.status_code} - {response.text}"
        except Exception as e:
            return f"Error processing image: {str(e)}"

    def perform_math_ocr(self, image_path):
        """
        Perform OCR on an image using EasyOCR and return the extracted text.

        Args:
            image_path (str): Path to the image file.

        Returns:
            str: Space-joined text of all detections, stripped, or a message
            beginning with "Error during Math OCR:" on failure.
        """
        try:
            # Let Pillow reject files it cannot identify, and close the handle
            # right away: EasyOCR is given the path and reads the file itself.
            with Image.open(image_path):
                pass

            results = self.reader.readtext(image_path)
            # Each result's element [1] is the detected text string.
            extracted_text = ' '.join(result[1] for result in results)
            return extracted_text.strip()
        except Exception as e:
            return f"Error during Math OCR: {e}"

    def get_response(self, user_input, image=None):
        """
        Generate a response using chat completions with improved error handling

        Args:
            user_input (str): User's message
            image (optional): Uploaded image (anything caption_image accepts)

        Returns:
            The streaming iterator from ``chat_completion``, or an error string
            starting with "Error generating response:".
        """
        try:
            # System prompt first, then remembered facts, then prior turns.
            messages = [ChatMessage(role="system", content=self.system_prompt).to_dict()]

            # Add persistent memory context if available
            if self.persistent_memory:
                memory_context = "Remembered Information:\n" + "\n".join(
                    [f"{k}: {v}" for k, v in self.persistent_memory.items()]
                )
                messages.append(ChatMessage(role="system", content=memory_context).to_dict())

            # Convert existing conversation history to ChatMessage dictionaries
            for msg in self.conversation_history:
                messages.append(ChatMessage(role=msg['role'], content=msg['content']).to_dict())

            # Process image if uploaded
            if image:
                image_caption = self.caption_image(image)
                user_input = f"Uploaded image : {image_caption}\n\nUser's message: {user_input}"

            messages.append(ChatMessage(role="user", content=user_input).to_dict())

            # Rough budget: whitespace-separated words stand in for tokens,
            # against the 16384-token total this code assumes, minus 50 spare.
            input_tokens = sum(len(msg['content'].split()) for msg in messages)
            max_new_tokens = 16384 - input_tokens - 50
            # Limit max_new_tokens to prevent exceeding the total limit
            max_new_tokens = min(max_new_tokens, 10020)
            # A long history can exhaust the budget; sending max_tokens <= 0
            # would only fail at the API with a less helpful message.
            if max_new_tokens < 1:
                return (
                    "Error generating response: the conversation is too long. "
                    "Please clear the conversation and try again."
                )

            # Generate response with streaming
            stream = self.client.chat_completion(
                messages=messages,
                model=self.MODEL_ID,
                temperature=0.7,
                max_tokens=max_new_tokens,
                top_p=0.9,
                stream=True,
            )
            return stream
        except Exception as e:
            print(f"Detailed error in get_response: {e}")
            return f"Error generating response: {str(e)}"

    def messages_to_prompt(self, messages):
        """
        Convert a list of ChatMessage dictionaries to a single prompt string.

        This is a simple implementation and you might need to adjust it
        based on the specific requirements of the model you are using.
        Messages with roles other than system/user/assistant are skipped.
        The result ends with an open ``<|assistant|>`` turn.
        """
        prompt = ""
        for msg in messages:
            if msg["role"] == "system":
                prompt += f"<|system|>\n{msg['content']}<|end|>\n"
            elif msg["role"] == "user":
                prompt += f"<|user|>\n{msg['content']}<|end|>\n"
            elif msg["role"] == "assistant":
                prompt += f"<|assistant|>\n{msg['content']}<|end|>\n"
        prompt += "<|assistant|>\n"  # Start of assistant's turn
        return prompt

    def create_interface(self):
        """Build and return the Gradio Blocks app wired to this instance."""

        def streaming_response(message, chat_history, image_filepath, math_ocr_image_path):
            """Gradio handler yielding (textbox, chat history, image, math image).

            Every yield clears the textbox and both image inputs. The new
            user/assistant pair is added to ``self.conversation_history`` only
            after the stream finishes without error.
            """
            if math_ocr_image_path:
                ocr_text = self.perform_math_ocr(math_ocr_image_path)
                # Match perform_math_ocr's full error prefix so recognised text
                # that merely begins with "Error" is not treated as a failure.
                if ocr_text.startswith("Error during Math OCR:"):
                    updated_history = chat_history + [[message, ocr_text]]
                    yield "", updated_history, None, None
                    return
                message = f"Math OCR Result: {ocr_text}\n\nUser's message: {message}"

            # get_response ignores a falsy image, so an empty upload is harmless.
            response_stream = self.get_response(message, image_filepath)

            # get_response signals failure by returning a string, not a stream
            if isinstance(response_stream, str):
                updated_history = chat_history + [[message, response_stream]]
                yield "", updated_history, None, None
                return

            full_response = ""
            updated_history = chat_history + [[message, ""]]

            try:
                for chunk in response_stream:
                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                        full_response += chunk.choices[0].delta.content
                        # Update the last message in chat history with partial response
                        updated_history[-1][1] = full_response
                        yield "", updated_history, None, None
            except Exception as e:
                print(f"Streaming error: {e}")
                # Display error in the chat interface
                updated_history[-1][1] = f"Error during response: {e}"
                yield "", updated_history, None, None
                return

            self.conversation_history.append({"role": "user", "content": message})
            self.conversation_history.append({"role": "assistant", "content": full_response})

            # Keep only the 10 most recent messages for the next request
            if len(self.conversation_history) > 10:
                self.conversation_history = self.conversation_history[-10:]

        def clear_chat():
            """Empty the Chatbot display AND the history sent to the model.

            Persistent memory is left intact (that is the "Clear Memory" button).
            """
            self.conversation_history = []
            return None

        # Custom CSS for Inter font and improved styling
        custom_css = """
        body, .gradio-container {
            font-family: 'Inter', sans-serif !important;
        }
        .chatbot-container .message {
            font-family: 'Inter', sans-serif !important;
        }
        .gradio-container input,
        .gradio-container textarea,
        .gradio-container button {
            font-family: 'Inter', sans-serif !important;
        }
        /* Image Upload Styling */
        .image-container {
            border: 1px solid #ccc;
            border-radius: 8px;
            padding: 10px;
            margin-bottom: 10px;
            display: flex;
            flex-direction: column;
            align-items: center;
            gap: 10px;
            background-color: #f8f8f8;
        }
        .image-preview {
            max-width: 200px;
            max-height: 200px;
            border-radius: 8px;
        }
        .image-buttons {
            display: flex;
            gap: 10px;
        }
        .image-buttons button {
            padding: 8px 15px;
            border-radius: 5px;
            background-color: #4CAF50;
            color: white;
            border: none;
            cursor: pointer;
        }
        .image-buttons button:hover {
            background-color: #367c39;
        }
        """

        with gr.Blocks(theme='soft', css=custom_css) as demo:
            # Chat interface with improved styling
            with gr.Column():
                chatbot = gr.Chatbot(
                    label="Xylaria 1.5 Senoa (EXPERIMENTAL)",
                    height=500,
                    show_copy_button=True,
                )

                # Image for captioning (passed to get_response as `image`)
                with gr.Accordion("Image Input", open=False):
                    with gr.Column():
                        img = gr.Image(
                            sources=["upload", "webcam"],
                            type="filepath",
                            label="",  # Remove label as it's redundant
                            elem_classes="image-preview",
                        )
                        with gr.Row():
                            clear_image_btn = gr.Button("Clear Image")

                # Image whose OCR text is prepended to the message
                with gr.Accordion("Math Input", open=False):
                    with gr.Column():
                        math_ocr_img = gr.Image(
                            sources=["upload", "webcam"],
                            type="filepath",
                            label="Upload Image for math",
                            elem_classes="image-preview",
                        )
                        with gr.Row():
                            clear_math_ocr_btn = gr.Button("Clear Math Image")

                # Input row with improved layout
                with gr.Row():
                    with gr.Column(scale=4):
                        txt = gr.Textbox(
                            show_label=False,
                            placeholder="Type your message...",
                            container=False,
                        )
                    btn = gr.Button("Send", scale=1)

                # Clear history and memory buttons
                with gr.Row():
                    clear = gr.Button("Clear Conversation")
                    clear_memory = gr.Button("Clear Memory")

            clear_image_btn.click(
                fn=lambda: None,
                inputs=None,
                outputs=[img],
                queue=False,
            )
            clear_math_ocr_btn.click(
                fn=lambda: None,
                inputs=None,
                outputs=[math_ocr_img],
                queue=False,
            )

            # Submit functionality with streaming and image support
            btn.click(
                fn=streaming_response,
                inputs=[txt, chatbot, img, math_ocr_img],
                outputs=[txt, chatbot, img, math_ocr_img],
            )
            txt.submit(
                fn=streaming_response,
                inputs=[txt, chatbot, img, math_ocr_img],
                outputs=[txt, chatbot, img, math_ocr_img],
            )

            # Clear the display and the model-side history (memory is kept)
            clear.click(
                fn=clear_chat,
                inputs=None,
                outputs=[chatbot],
                queue=False,
            )

            # Clear persistent memory and reset conversation
            clear_memory.click(
                fn=self.reset_conversation,
                inputs=None,
                outputs=[chatbot],
                queue=False,
            )

            # Runs on every page load (not when the page is closed), so a fresh
            # visit starts with empty history and memory.
            # NOTE(review): state lives on this shared instance, so one visitor
            # loading the page appears to reset it for everyone connected.
            demo.load(self.reset_conversation, None, None)

        return demo
# Launch the interface
def main():
    """Construct the chat app and serve it with Gradio (blocks until shutdown)."""
    app = XylariaChat().create_interface()
    # debug=True surfaces full tracebacks; share=True requests a public link.
    app.launch(debug=True, share=True)


if __name__ == "__main__":
    main()