# Captured from a Hugging Face Space (status at capture time: Sleeping).
# File size: 5,140 bytes. Revisions touching this file: 6523110, 9eafc14, 19ee29e, 115de81.
import functools
import json
import os
import uuid
from datetime import datetime

# Disable Hugging Face tokenizers parallelism (avoids fork-related warnings).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import chromadb
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor
def initialize_collection(collection_name="clip_image_embeddings", path="./chroma_db"):
    """Return the ChromaDB collection *collection_name*, creating it if absent.

    A ``PersistentClient`` rooted at *path* is used, so the collection and its
    embeddings are stored on disk between runs.

    Args:
        collection_name: Name of the collection to open or create.
        path: Directory of the persistent ChromaDB store.

    Returns:
        The opened (existing) or newly created collection.
    """
    client = chromadb.PersistentClient(path=path)
    # Depending on the chromadb version, list_collections() yields either
    # Collection objects (exposing .name) or plain name strings (the v0.6.0
    # change). Normalise to names so the membership test works with both.
    existing_names = {
        col if isinstance(col, str) else col.name
        for col in client.list_collections()
    }
    if collection_name in existing_names:
        collection = client.get_collection(name=collection_name)
        print(f"Using existing collection: {collection_name}")
    else:
        collection = client.create_collection(name=collection_name)
        print(f"Created new collection: {collection_name}")
    return collection
def main_create_image_collection(df_emb, collection_name="clip_image_embeddings", batch_size=500):
    """Index the image embeddings of *df_emb* into a ChromaDB collection.

    Each row is stored under its positional index (as a string id), with the
    row's ``image_url`` as the document and ``fashion_clip_image`` as the
    embedding. Rows whose embedding is ``None`` are skipped, and so are ids
    already present in the collection. Re-running against the same
    collection therefore only adds missing rows.

    Args:
        df_emb: DataFrame with ``image_url`` and ``fashion_clip_image`` columns.
        collection_name: Name of the collection to open or create.
        batch_size: Rows checked/added per ChromaDB round trip. The default
            stays well under SQLite's bound-parameter limit.

    Returns:
        The ChromaDB collection.

    Raises:
        ValueError: If *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    documents = df_emb["image_url"].tolist()
    embeddings_all = df_emb["fashion_clip_image"].tolist()

    collection = initialize_collection(collection_name)

    # Only rows that actually carry an embedding are indexed; ids are positional.
    pending = [
        (str(i), embedding, document)
        for i, (embedding, document) in enumerate(zip(embeddings_all, documents))
        if embedding is not None
    ]

    for start in range(0, len(pending), batch_size):
        batch = pending[start:start + batch_size]
        # One existence lookup per batch instead of one per row.
        existing = set(collection.get(ids=[entry_id for entry_id, _, _ in batch])["ids"])
        new_rows = [row for row in batch if row[0] not in existing]
        if not new_rows:
            continue
        collection.add(
            ids=[entry_id for entry_id, _, _ in new_rows],
            # Plain float lists are accepted by every chromadb version
            # (the values may arrive as lists or numpy arrays).
            embeddings=[[float(v) for v in emb] for _, emb, _ in new_rows],
            documents=[doc for _, _, doc in new_rows],
        )
    return collection
# CLIP checkpoint used to embed text queries. The dataset's
# `fashion_clip_image` embeddings are presumably from the same model — TODO confirm.
model_name = "patrickjohncyh/fashion-clip"
# Prefer an accelerator when available: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)
def _first_item(value):
    """Return ``value[0]`` for a non-empty list/tuple/ndarray, else ``None``."""
    # DataFrame list columns may hold Python lists or numpy arrays depending on
    # how they were converted; `if value` is ambiguous for arrays, so test len().
    if isinstance(value, (list, tuple, np.ndarray)) and len(value) > 0:
        return value[0]
    return None


@functools.lru_cache(maxsize=1)
def _load_embeddings_df():
    """Load the fashion-show dataset as a DataFrame with one row per image URL.

    The result is cached for the lifetime of the process so repeated queries
    do not reload and re-convert the whole dataset. Callers must treat the
    returned DataFrame as read-only.
    """
    # With no split specified, load_dataset returns a DatasetDict keyed by
    # split name; use the first split.
    dataset = load_dataset("traopia/fashion_show_data_all_embeddings.json")
    first_split = next(iter(dataset))
    df_emb = dataset[first_split].to_pandas()
    df_emb["fashion_clip_image"] = df_emb["fashion_clip_image"].apply(_first_item)
    df_emb["image_url"] = df_emb["image_urls"].apply(_first_item)
    # Deduplicate on the scalar first URL only. The list-valued `image_urls`
    # column is unhashable, and rows with identical URL lists share the same
    # first URL anyway, so this keeps exactly the same rows.
    return df_emb.drop_duplicates(subset="image_url")


def main_text_retrieve_images(text, result_query=None, n_retrieved=3):
    """Return dataset rows whose image embeddings best match a text query.

    The query is embedded with the module-level CLIP ``model``/``processor``
    and searched against a ChromaDB collection of image embeddings.

    Args:
        text: Free-text query.
        result_query: Optional iterable of image URLs that restricts the
            search. When given, a temporary collection holding only those rows
            is built for this call and deleted afterwards. When falsy, the
            persistent default collection is used.
        n_retrieved: Maximum number of matches to return.

    Returns:
        A list of row dicts ordered as ChromaDB ranked them (closest first).
        The ``fashion_clip_image``, ``description``, ``editor``,
        ``publish_date`` and ``image_urls`` columns are removed. The list is
        empty when there is nothing to search.
    """
    df_emb = _load_embeddings_df()
    temp_collection_name = None

    if result_query:
        wanted = pd.DataFrame(result_query, columns=["image_url"]).drop_duplicates()
        df_emb = df_emb.merge(wanted, on="image_url", how="inner")
        if df_emb.empty:
            return []
        # Timestamp plus random suffix: concurrent calls within the same second
        # must not share (and reuse stale rows from) one collection.
        temp_collection_name = (
            f"collection_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            f"_{uuid.uuid4().hex[:8]}"
        )
        collection = main_create_image_collection(df_emb, collection_name=temp_collection_name)
    else:
        collection = main_create_image_collection(df_emb)

    try:
        if collection.count() == 0:
            return []
        # max_length=77 matches CLIP's text context length.
        inputs = processor(
            text=[text], return_tensors="pt", padding=True, truncation=True, max_length=77
        ).to(device)
        with torch.no_grad():
            text_features = model.get_text_features(**inputs).cpu().numpy()
        results = collection.query(
            query_embeddings=[text_features[0].tolist()],
            n_results=n_retrieved,
        )
    finally:
        # Per-call collections would otherwise accumulate on disk forever.
        if temp_collection_name is not None:
            chromadb.PersistentClient(path="./chroma_db").delete_collection(name=temp_collection_name)

    ranked = pd.DataFrame(results["documents"][0], columns=["image_url"])
    # The ranked list goes on the LEFT so the inner merge keeps similarity
    # order; then restore df_emb's column order.
    df_result = ranked.merge(df_emb, on="image_url", how="inner")[list(df_emb.columns)]
    df_result = df_result.drop(
        columns=["fashion_clip_image", "description", "editor", "publish_date", "image_urls"],
        errors="ignore",
    )
    return df_result.to_dict(orient="records")
if __name__ == "__main__":
    # Manual smoke test: print the best matches for a single sample query.
    print(main_text_retrieve_images("dress"))