# Hugging Face Space status: Sleeping
import json | |
import pandas as pd | |
from transformers import CLIPProcessor, CLIPModel | |
import torch | |
import os | |
os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
from datasets import load_dataset | |
import chromadb | |
from datetime import datetime | |
def initialize_collection(collection_name="clip_image_embeddings"):
    """Return the persistent ChromaDB collection called ``collection_name``.

    The collection is opened when it already exists in the on-disk store at
    ``./chroma_db`` and created otherwise. A message stating which of the two
    happened is printed.

    Args:
        collection_name: Name of the collection to open or create.

    Returns:
        The ChromaDB collection object.
    """
    # PersistentClient keeps the embeddings on disk between runs.
    client = chromadb.PersistentClient(path="./chroma_db")

    # Depending on the installed ChromaDB release, list_collections() yields
    # either plain names (the v0.6.0 behaviour) or objects exposing ``.name``.
    # Normalise both shapes to a set of names so the membership test is valid.
    existing_names = {
        entry if isinstance(entry, str) else entry.name
        for entry in client.list_collections()
    }

    if collection_name in existing_names:
        collection = client.get_collection(name=collection_name)
        print(f"Using existing collection: {collection_name}")
    else:
        collection = client.create_collection(name=collection_name)
        print(f"Created new collection: {collection_name}")
    return collection
def main_create_image_collection(df_emb, collection_name="clip_image_embeddings", batch_size=500):
    """Store the image embeddings of ``df_emb`` in a ChromaDB collection.

    Each row is stored under an ID equal to its zero-based row position
    (as a string), with the row's ``image_url`` as the document. Rows whose
    ID is already present in the collection are left untouched, so calling
    this repeatedly on the same frame does not duplicate entries.

    Rows lacking an embedding or an image URL are skipped.

    Args:
        df_emb: DataFrame with an ``image_url`` column and a
            ``fashion_clip_image`` column holding one embedding per row.
        collection_name: Name of the collection to fill.
        batch_size: Number of rows checked and inserted per ChromaDB call.

    Returns:
        The populated ChromaDB collection.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        KeyError: If ``df_emb`` lacks one of the required columns.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    def _is_missing(value):
        # pandas can represent a missing cell as NaN instead of None.
        return value is None or (isinstance(value, float) and value != value)

    image_urls = df_emb["image_url"].tolist()
    embeddings_all = df_emb["fashion_clip_image"].tolist()

    # IDs come from the row position, so skipped rows still consume their
    # index and the remaining IDs stay aligned with the frame.
    pending = []
    for position, (image_url, embedding) in enumerate(zip(image_urls, embeddings_all)):
        # NOTE(review): a None document was previously forwarded to
        # collection.add(); ChromaDB presumably expects string documents,
        # so such rows are skipped here -- confirm against the client docs.
        if _is_missing(image_url) or _is_missing(embedding):
            continue
        if hasattr(embedding, "tolist"):
            # numpy vectors are converted to plain lists of floats.
            embedding = embedding.tolist()
        pending.append((str(position), image_url, embedding))

    collection = initialize_collection(collection_name)

    # One existence check and one insert per batch instead of per row.
    for start in range(0, len(pending), batch_size):
        batch = pending[start:start + batch_size]
        existing_entry = collection.get(ids=[entry_id for entry_id, _, _ in batch])
        already_stored = set(existing_entry["ids"]) if existing_entry else set()
        fresh = [entry for entry in batch if entry[0] not in already_stored]
        if not fresh:
            continue
        collection.add(
            ids=[entry_id for entry_id, _, _ in fresh],
            embeddings=[vector for _, _, vector in fresh],
            documents=[url for _, url, _ in fresh],
        )
    return collection
# Checkpoint used for both the CLIP model and its processor.
model_name = "patrickjohncyh/fashion-clip"

# Pick the fastest available backend: CUDA first, then Apple MPS, then CPU.
# ``torch.backends.mps`` is looked up defensively because it is absent from
# older torch builds.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Loaded once at import time and shared by every retrieval call.
model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)
def main_text_retrieve_images(text, result_query=None, n_retrieved=3):
    """Return the dataset rows whose images best match a text query.

    The embeddings dataset is loaded, reduced to one row per first image
    URL, stored in a ChromaDB collection, and queried with the CLIP text
    embedding of ``text``.

    Args:
        text: Free-text query to embed with the CLIP text encoder.
        result_query: Optional rows restricting the search; it is turned
            into a DataFrame with an ``image_url`` column and only dataset
            rows with a matching URL are searched. When falsy, the whole
            dataset is searched through the default collection.
        n_retrieved: Number of nearest images requested from ChromaDB.

    Returns:
        A list of row dictionaries, ordered from best to worst match, with
        the embedding and bookkeeping columns removed.
    """

    def _first_or_none(value):
        # Return the first element of a list-like cell, or None when the
        # cell is missing, empty, or not a sequence. Length is tested
        # instead of the exact type because list columns converted from
        # Arrow arrive as numpy arrays rather than Python lists.
        if value is None or isinstance(value, (str, bytes)):
            return None
        try:
            if len(value) == 0:
                return None
        except TypeError:
            return None
        first = value[0]
        # Nested numpy vectors become plain lists of floats.
        return first.tolist() if hasattr(first, "tolist") else first

    # load_dataset returns a dict of splits; the first split is used.
    dataset = load_dataset("traopia/fashion_show_data_all_embeddings.json")
    split_name = next(iter(dataset.keys()))
    df_emb = dataset[split_name].to_pandas()

    df_emb["fashion_clip_image"] = df_emb["fashion_clip_image"].apply(_first_or_none)
    df_emb["image_url"] = df_emb["image_urls"].apply(_first_or_none)
    # Deduplicating on the first URL also covers rows whose whole URL list is
    # identical, and avoids hashing list-valued ``image_urls`` cells.
    df_emb = df_emb.drop_duplicates(subset="image_url")

    temporary_name = None
    if result_query:
        df_small = pd.DataFrame(result_query, columns=["image_url"])
        # Repeated URLs in the filter must not multiply the matching rows.
        wanted_urls = df_small[["image_url"]].drop_duplicates()
        df_emb = df_emb.merge(wanted_urls, on="image_url", how="inner")
        # Microseconds keep two queries issued in the same second from
        # sharing, and therefore reusing, one collection.
        temporary_name = f"collection_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
        collection = main_create_image_collection(df_emb, collection_name=temporary_name)
    else:
        collection = main_create_image_collection(df_emb)

    try:
        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,
        ).to(device)
        with torch.no_grad():
            text_features = model.get_text_features(**inputs).cpu().numpy()

        results = collection.query(
            query_embeddings=[text_features[0].tolist()],
            n_results=n_retrieved,
        )
    finally:
        if temporary_name is not None:
            # Per-query collections are never reused; drop them so the
            # on-disk store does not grow with every filtered search. The
            # path must match the one used by initialize_collection.
            chromadb.PersistentClient(path="./chroma_db").delete_collection(name=temporary_name)

    # Merging with the query results on the left keeps the similarity
    # ranking instead of falling back to dataset order.
    result_doc = pd.DataFrame(results["documents"][0], columns=["image_url"])
    df_result = result_doc.merge(df_emb, on="image_url", how="inner")

    # Drop the embedding and the fields not meant for the caller; columns
    # absent from the dataset are ignored rather than raising.
    df_result = df_result.drop(
        columns=["fashion_clip_image", "description", "editor", "publish_date", "image_urls"],
        errors="ignore",
    )
    return df_result.to_dict(orient="records")
if __name__ == "__main__":
    # Manual smoke run: retrieve the matches for a sample query and show them.
    sample_query = "dress"
    print(main_text_retrieve_images(sample_query))