"""Streamlit front end for a fashion recommender with an attached chatbot.

Users can upload an image to get visually similar products (ResNet50
embeddings + brute-force nearest-neighbour search over precomputed
embeddings) and ask a chatbot questions about products.
"""

import os
import pickle

import numpy as np
import streamlit as st
import tensorflow as tf  # Make sure this import is included
from numpy.linalg import norm
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.layers import GlobalMaxPooling2D
from tensorflow.keras.preprocessing import image

from chatbot import Chatbot  # Assuming you have a chatbot module

UPLOAD_DIR = 'uploads'
NUM_RECOMMENDATIONS = 5  # number of result columns shown for an upload

# (label, product_data key) pairs rendered for each chatbot recommendation.
_PRODUCT_FIELDS = (
    ("Product Name:", 'productDisplayName'),
    ("Category:", 'masterCategory'),
    ("Article Type:", 'articleType'),
    ("Usage:", 'usage'),
    ("Season:", 'season'),
    ("Gender:", 'gender'),
)

# Streamlit reruns this whole script on every interaction; st.cache_resource
# keeps heavy objects (model, embeddings) alive across reruns. Older
# Streamlit versions lack it, so fall back to a no-op decorator there.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model():
    """Build the frozen ResNet50 + global-max-pooling feature extractor."""
    base = ResNet50(weights='imagenet', include_top=False,
                    input_shape=(224, 224, 3))
    base.trainable = False
    return tf.keras.Sequential([base, GlobalMaxPooling2D()])


@_cache_resource
def _load_embeddings():
    """Load the precomputed embedding matrix and the matching filenames.

    Returns:
        ``(feature_list, filenames)`` where ``feature_list`` is a NumPy array
        built from ``embeddings.pkl`` and ``filenames`` is whatever object
        ``filenames.pkl`` contains.
    """
    # NOTE: unpickling can execute arbitrary code; these files must come
    # from a trusted source (they are local build artifacts, not uploads).
    with open('embeddings.pkl', 'rb') as f:
        feature_list = np.array(pickle.load(f))
    with open('filenames.pkl', 'rb') as f:
        filenames = pickle.load(f)
    return feature_list, filenames


# Define function for feature extraction
def feature_extraction(img_path, model):
    """Return the L2-normalised embedding of the image at ``img_path``.

    The image is resized to 224x224, preprocessed with ResNet50's
    ``preprocess_input`` and passed through ``model`` as a batch of one;
    the output is flattened to a 1-D vector.

    Args:
        img_path: Path to an image file readable by Keras ``load_img``.
        model: Keras model accepting a (1, 224, 224, 3) batch.

    Returns:
        1-D NumPy array with unit L2 norm, or the raw (all-zero) vector if
        its norm is zero.
    """
    img = image.load_img(img_path, target_size=(224, 224))
    img_array = image.img_to_array(img)
    expanded_img_array = np.expand_dims(img_array, axis=0)
    preprocessed_img = preprocess_input(expanded_img_array)
    result = model.predict(preprocessed_img).flatten()
    magnitude = norm(result)
    if magnitude == 0:
        # Avoid a 0/0 division producing NaNs for an all-zero embedding.
        return result
    return result / magnitude


# Define function for recommendation
def recommend(features, feature_list, n_neighbors=6):
    """Return indices of the rows in ``feature_list`` closest to ``features``.

    Args:
        features: 1-D query embedding.
        feature_list: 2-D array of candidate embeddings, one per row.
        n_neighbors: Maximum number of neighbours to return; clamped to the
            number of candidates so small datasets do not raise.

    Returns:
        The ``indices`` array from ``NearestNeighbors.kneighbors`` with shape
        (1, k), nearest first.
    """
    k = min(n_neighbors, len(feature_list))
    neighbors = NearestNeighbors(n_neighbors=k, algorithm='brute',
                                 metric='euclidean')
    neighbors.fit(feature_list)
    _, indices = neighbors.kneighbors([features])
    return indices


# Function to save uploaded file
def save_uploaded_file(uploaded_file):
    """Write a Streamlit upload into ``UPLOAD_DIR`` and return its path.

    Reports success or failure in the UI. Returns ``None`` if saving fails.
    """
    try:
        os.makedirs(UPLOAD_DIR, exist_ok=True)
        # The filename is client-supplied: keep only its last component so a
        # name like "../x" cannot write outside UPLOAD_DIR.
        safe_name = os.path.basename(uploaded_file.name)
        file_path = os.path.join(UPLOAD_DIR, safe_name)
        with open(file_path, 'wb') as f:
            f.write(uploaded_file.getbuffer())
    except Exception as e:
        st.error(f"Error saving file: {e}")
        return None
    st.success(f"File saved to {file_path}")
    return file_path


def _show_recommended_images(chatbot, indices):
    """Render up to NUM_RECOMMENDATIONS images for the given dataset indices."""
    columns = st.columns(NUM_RECOMMENDATIONS)
    for col, idx in zip(columns, indices):
        try:
            # Directly access images from the dataset instead of file paths.
            # NOTE(review): assumes chatbot.images is ordered like the rows
            # of embeddings.pkl — confirm.
            image_data = chatbot.images[idx]
            if image_data is not None:
                with col:
                    st.image(image_data)
        except Exception as e:
            st.error(f"Error opening image index {idx}: {e}")


def _show_image_recommendations(chatbot):
    """Upload section: save the image, embed it and show similar products.

    Every failure is reported with ``st.error`` and only ends this section,
    so the rest of the dashboard still renders.
    """
    try:
        feature_list, filenames = _load_embeddings()
    except Exception as e:
        st.error(f"Error loading pickle files: {e}")
        return

    # Show the filenames to verify (collapsed so it doesn't flood the page).
    with st.expander("Loaded filenames"):
        st.write("List of filenames loaded:")
        st.write(filenames)

    uploaded_file = st.file_uploader("Choose an image")
    if uploaded_file is None:
        return

    file_path = save_uploaded_file(uploaded_file)
    if not file_path:
        st.error("Some error occurred in file upload")
        return

    try:
        # Context manager closes the file; st.image consumes it immediately.
        with Image.open(file_path) as display_image:
            st.image(display_image)
    except Exception as e:
        st.error(f"Error displaying uploaded image: {e}")

    try:
        model = _load_model()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return

    try:
        features = feature_extraction(file_path, model)
    except Exception as e:
        st.error(f"Error extracting features: {e}")
        return

    try:
        indices = recommend(features, feature_list)
    except Exception as e:
        st.error(f"Error in recommendation: {e}")
        return

    _show_recommended_images(chatbot, indices[0])


def _show_chatbot(chatbot):
    """Chatbot section: answer a question and list the products it returns."""
    user_question = st.text_input("Ask a question:")
    if not user_question:
        return

    try:
        bot_response, recommended_products = chatbot.generate_response(
            user_question)
    except Exception as e:
        st.error(f"Error generating chatbot response: {e}")
        return
    st.write("Chatbot:", bot_response)

    # Display recommended products
    for result in recommended_products:
        pid = result['corpus_id']
        product_info = chatbot.product_data[pid]
        for label, key in _PRODUCT_FIELDS:
            st.write(label, product_info[key])
        st.image(chatbot.images[pid])


# Function to show dashboard content
def show_dashboard():
    """Render the image-recommendation section followed by the chatbot."""
    st.header("Fashion Recommender System")
    chatbot = Chatbot()
    _show_image_recommendations(chatbot)
    # Rendered independently, so upload/model errors never hide the chatbot.
    _show_chatbot(chatbot)


# Main Streamlit app
def main():
    """Entry point: give the app its title and show the dashboard."""
    st.title("Fashion Recommender System")
    show_dashboard()


# Run the main app
if __name__ == "__main__":
    main()