Spaces:
No application file
No application file
| import chromadb | |
| from PIL import Image as PILImage | |
| import streamlit as st | |
| import os | |
| from utils.qa import chain | |
| from langchain.memory import ConversationBufferWindowMemory | |
| from langchain_community.chat_message_histories import StreamlitChatMessageHistory | |
| import base64 | |
| import io | |
# Initialize Chromadb client
# Persistent vector store on local disk; both collections must already exist
# (get_collection raises if the ingestion step never created them).
path = "mm_vdb2"
client = chromadb.PersistentClient(path=path)
image_collection = client.get_collection(name="image")
video_collection = client.get_collection(name='video_collection')
# Set up memory storage for the chat
# Chat history is persisted in Streamlit session state under "chat_messages";
# k=3 keeps only the last three exchanges in the conversational window.
memory_storage = StreamlitChatMessageHistory(key="chat_messages")
memory = ConversationBufferWindowMemory(memory_key="chat_history", human_prefix="User", chat_memory=memory_storage, k=3)
# Function to get an answer from the chain
def get_answer(query):
    """Run *query* through the QA chain and return its textual result.

    Falls back to a placeholder string when the chain response carries
    no "result" key.
    """
    chain_output = chain.invoke(query)
    if "result" in chain_output:
        return chain_output["result"]
    return "No result found."
# Function to display images in the UI
def display_images(image_collection, query_text, max_distance=None, debug=False):
    """Query *image_collection* for *query_text* and render the matches
    in a three-column grid.

    Parameters:
        image_collection: Chroma collection queried with ``query_texts``.
        query_text: the user's search string.
        max_distance: if given, results with a larger embedding distance
            are skipped.
        debug: when True, write each candidate's URI and distance to the UI.
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=10,
        include=['uris', 'distances']
    )
    uris = results['uris'][0]
    distances = results['distances'][0]
    # Sort by URI so the gallery order is stable across reruns.
    sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])
    cols = st.columns(3)
    # Count only images actually rendered; using the raw loop index here
    # would leave blank cells in the grid whenever a result is filtered
    # out or fails to load.
    shown = 0
    for uri, distance in sorted_results:
        if max_distance is not None and distance > max_distance:
            if debug:
                st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")
            continue
        if debug:
            st.write(f"URI: {uri} - Distance: {distance}")
        try:
            img = PILImage.open(uri)
            with cols[shown % 3]:
                st.image(img, use_container_width=True)
            shown += 1
        except Exception as e:
            st.error(f"Error loading image: {e}")
# Function to display videos in the UI
def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
    """Render the videos whose frames match *query_text*.

    Each frame hit carries the source video in its metadata; a video is
    shown at most once even if several of its frames match. Hits beyond
    *max_distance* (when given) are skipped. With ``debug=True`` every
    candidate is logged to the UI, including the filtered-out ones.
    """
    seen = set()
    hits = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances', 'metadatas']
    )
    candidates = zip(hits['uris'][0], hits['distances'][0], hits['metadatas'][0])
    for frame_uri, dist, meta in candidates:
        source = meta['video_uri']
        close_enough = max_distance is None or dist <= max_distance
        if close_enough and source not in seen:
            if debug:
                st.write(f"URI: {frame_uri} - Video URI: {source} - Distance: {dist}")
            st.video(source)
            seen.add(source)
        elif debug:
            st.write(f"URI: {frame_uri} - Video URI: {source} - Distance: {dist} (Filtered out)")
def _uris_under_distance(collection, query_text, max_distance, n_results=5):
    """Return the result URIs for *query_text* whose embedding distance is
    within *max_distance* (all of them when max_distance is None)."""
    results = collection.query(
        query_texts=[query_text],
        n_results=n_results,
        include=['uris', 'distances']
    )
    return [
        uri
        for uri, distance in zip(results['uris'][0], results['distances'][0])
        if max_distance is None or distance <= max_distance
    ]

# Function to format the inputs for image and video processing
def format_prompt_inputs(image_collection, video_collection, user_query):
    """Build the prompt-input dict: the query text, the best matching video
    frame URI, and a base64-encoded, downscaled grayscale JPEG of the best
    matching image (empty strings when nothing qualifies).

    Returns a dict with keys "query", "frame", "image_data_1".
    """
    # NOTE(review): the original called frame_uris()/image_uris(), which are
    # not defined or imported anywhere in this module (NameError at runtime);
    # both are realized here by the shared _uris_under_distance helper.
    frame_candidates = _uris_under_distance(video_collection, user_query, max_distance=1.55)
    image_candidates = _uris_under_distance(image_collection, user_query, max_distance=1.5)
    inputs = {"query": user_query}
    inputs["frame"] = frame_candidates[0] if frame_candidates else ""
    if image_candidates:
        with PILImage.open(image_candidates[0]) as img:
            # Shrink to 1/6 linear size and drop to grayscale so the
            # base64 payload sent to the model stays small.
            img = img.resize((img.width // 6, img.height // 6))
            img = img.convert("L")
            with io.BytesIO() as output:
                img.save(output, format="JPEG", quality=60)
                compressed_image_data = output.getvalue()
        inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')
    else:
        inputs["image_data_1"] = ""
    return inputs
# Main function to initialize and run the UI
def home():
    """Render the Virtual Tutor chat page.

    Draws the header/banner, replays the stored chat history, and for each
    new user message: answers it via the QA chain, records both turns in
    session state, then shows related images and the matching source video.
    """
    # Set up the page layout
    st.set_page_config(layout='wide', page_title="Virtual Tutor")
    # Header
    st.header("Welcome to Virtual Tutor - CHAT")
    # SVG Banner for UI branding
    st.markdown("""
    <svg width="600" height="100">
        <text x="50%" y="50%" font-family="San serif" font-size="42px" fill="Black" text-anchor="middle" stroke="white"
        stroke-width="0.3" stroke-linejoin="round">Virtual Tutor - CHAT
        </text>
    </svg>
    """, unsafe_allow_html=True)
    # Initialize the chat session if not already initialized
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hi! How may I assist you today?"}]
    # Styling for the chat input container
    st.markdown("""
    <style>
    .stChatInputContainer > div {
        background-color: #000000;
    }
    </style>
    """, unsafe_allow_html=True)
    # Display previous chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
    # Display chat messages from memory
    # NOTE(review): roles are inferred from position (even index = user),
    # which assumes strict user/assistant alternation in memory_storage —
    # confirm the memory is always written in pairs.
    for i, msg in enumerate(memory_storage.messages):
        name = "user" if i % 2 == 0 else "assistant"
        st.chat_message(name).markdown(msg.content)
    # Handle user input and generate response
    if user_input := st.chat_input("Enter your question here..."):
        with st.chat_message("user"):
            st.markdown(user_input)
        with st.spinner("Generating Response..."):
            with st.chat_message("assistant"):
                response = get_answer(user_input)
                answer = response
                st.markdown(answer)
                # Save user and assistant messages to session state
                message = {"role": "assistant", "content": answer}
                message_u = {"role": "user", "content": user_input}
                st.session_state.messages.append(message_u)
                st.session_state.messages.append(message)
                # Process inputs for image/video
                inputs = format_prompt_inputs(image_collection, video_collection, user_input)
                # Display images
                st.markdown("### Images")
                display_images(image_collection, user_input, max_distance=1.55, debug=False)
                # Display videos based on frames
                st.markdown("### Videos")
                frame = inputs["frame"]
                if frame:
                    # assumes frame URIs look like "<dir>/<video_name>/..." so the
                    # second path segment names the source video — TODO confirm
                    directory_name = frame.split('/')[1]
                    video_path = f"videos_flattened/{directory_name}.mp4"
                    if os.path.exists(video_path):
                        st.video(video_path)
                    else:
                        st.error("Video file not found.")
# Call the home function to run the app
# (executed when the script is run directly, e.g. via `streamlit run`)
if __name__ == "__main__":
    home()