"""Text-to-image retrieval over Fashion-CLIP embeddings stored in ChromaDB."""

import json
import pandas as pd
from transformers import CLIPProcessor, CLIPModel
import torch
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
from datasets import load_dataset
import chromadb
from datetime import datetime

# Number of records sent to ChromaDB per get/add round trip.
_ADD_BATCH_SIZE = 1000


def initialize_collection(collection_name="clip_image_embeddings"):
    """Return the persistent ChromaDB collection named *collection_name*.

    The collection is opened if it already exists in ``./chroma_db`` and
    created otherwise.

    Args:
        collection_name: Name of the collection to open or create.

    Returns:
        The ChromaDB collection object.
    """
    # PersistentClient keeps the embeddings on disk between runs.
    client = chromadb.PersistentClient(path="./chroma_db")

    # Depending on the chromadb release, list_collections() yields either
    # collection objects (with a ``.name``) or bare name strings; accept both.
    existing_names = {
        getattr(entry, "name", entry) for entry in client.list_collections()
    }

    if collection_name in existing_names:
        collection = client.get_collection(name=collection_name)
        print(f"Using existing collection: {collection_name}")
    else:
        collection = client.create_collection(name=collection_name)
        print(f"Created new collection: {collection_name}")
    return collection


def main_create_image_collection(df_emb, collection_name="clip_image_embeddings"):
    """Store the image embeddings of *df_emb* in a ChromaDB collection.

    Each row is stored under the ID ``str(row_position)`` with the row's
    ``fashion_clip_image`` vector as embedding and its ``image_url`` as
    document. IDs that are already present in the collection are left alone.

    Args:
        df_emb: DataFrame with ``fashion_clip_image`` and ``image_url`` columns.
        collection_name: Name of the collection to fill.

    Returns:
        The ChromaDB collection object.
    """
    vectors = df_emb["fashion_clip_image"].tolist()
    image_urls = df_emb["image_url"].tolist()

    collection = initialize_collection(collection_name)

    # Rows without an embedding cannot be indexed. Rows without a URL are
    # skipped too: there is no image to hand back for them.
    # NOTE(review): newer chromadb releases appear to reject non-string
    # documents outright — confirm against the installed version.
    records = [
        (str(position), vector, url)
        for position, (vector, url) in enumerate(zip(vectors, image_urls))
        if vector is not None and url is not None
    ]

    # Check and insert in batches instead of one get() + add() per row.
    for start in range(0, len(records), _ADD_BATCH_SIZE):
        batch = records[start:start + _ADD_BATCH_SIZE]
        batch_ids = [record_id for record_id, _, _ in batch]

        # include=[] asks for the IDs only, not the stored payloads.
        stored = collection.get(ids=batch_ids, include=[])
        already_stored = set(stored["ids"])

        fresh = [record for record in batch if record[0] not in already_stored]
        if not fresh:
            continue

        collection.add(
            ids=[record_id for record_id, _, _ in fresh],
            embeddings=[vector for _, vector, _ in fresh],
            documents=[url for _, _, url in fresh],
        )

    return collection


model_name = "patrickjohncyh/fashion-clip"
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)


def main_text_retrieve_images(text, result_query=None, n_retrieved=3):
    """Return the dataset rows whose images best match *text*.

    Args:
        text: Free-text query that is embedded with the Fashion-CLIP text
            encoder.
        result_query: Optional data used to build a one-column ``image_url``
            DataFrame; when truthy, the search is restricted to those URLs
            and runs against a fresh, uniquely named collection.
        n_retrieved: Number of nearest neighbours requested from ChromaDB.

    Returns:
        A list of row dicts (``orient='records'``), most similar first, with
        the embedding and bookkeeping columns removed. Empty when there is
        nothing to search.
    """
    # No split is specified, so load_dataset returns a mapping of splits;
    # the first split is used as the whole dataset.
    dataset = load_dataset("traopia/fashion_show_data_all_embeddings.json")
    split_name = list(dataset.keys())[0]
    df_emb = dataset[split_name].to_pandas()

    df_emb = df_emb.drop_duplicates(subset="image_urls")
    df_emb["fashion_clip_image"] = df_emb["fashion_clip_image"].apply(
        lambda value: value[0] if isinstance(value, list) else None
    )
    df_emb["image_url"] = df_emb["image_urls"].apply(
        lambda value: value[0] if value else None
    )
    df_emb = df_emb.drop_duplicates(subset="image_url")

    if result_query:
        # Duplicate URLs in the filter would duplicate rows after the merge.
        df_small = pd.DataFrame(result_query, columns=["image_url"])
        df_small = df_small.drop_duplicates(subset="image_url")
        df_emb = df_emb.merge(df_small[["image_url"]], on="image_url", how="inner")
        # IDs are row positions, so a filtered subset needs its own collection.
        # Microseconds keep two calls within the same second from sharing one.
        # NOTE(review): these per-call collections are never deleted and
        # accumulate in ./chroma_db.
        collection_name = (
            f"collection_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
        )
    else:
        collection_name = "clip_image_embeddings"

    if df_emb.empty:
        return []

    collection = main_create_image_collection(
        df_emb, collection_name=collection_name
    )

    inputs = processor(
        text=[text],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,
    ).to(device)

    with torch.no_grad():
        text_features = model.get_text_features(**inputs).cpu().numpy()

    # A plain nested list is accepted for query_embeddings across releases.
    results = collection.query(
        query_embeddings=[text_features[0].tolist()],
        n_results=n_retrieved,
    )

    # Keep the ranked URLs on the left: an inner merge preserves the order of
    # the left keys, so the rows come back most-similar first.
    ranked = pd.DataFrame(results["documents"][0], columns=["image_url"])
    df_result = ranked.merge(df_emb, on="image_url", how="inner")

    # Drop the embedding and the columns callers do not need.
    df_result = df_result.drop(
        columns=[
            "fashion_clip_image",
            "description",
            "editor",
            "publish_date",
            "image_urls",
        ]
    )
    return df_result.to_dict(orient="records")


if __name__ == "__main__":
    text = "dress"
    result = main_text_retrieve_images(text)
    print(result)