Ask-FashionDB / src /visual_qa.py
traopia
name chroma
19ee29e
raw
history blame
5.14 kB
import json
import pandas as pd
from transformers import CLIPProcessor, CLIPModel
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from datasets import load_dataset
import chromadb
from datetime import datetime
def initialize_collection(collection_name="clip_image_embeddings"):
    """Open the persistent ChromaDB collection ``collection_name``, creating it if absent.

    NOTE(review): this module defines ``initialize_collection`` twice; the
    later definition replaces this one when the module is imported.

    Args:
        collection_name: Name of the collection to fetch or create.

    Returns:
        The ChromaDB collection object for ``collection_name``.
    """
    # PersistentClient stores embeddings on disk between runs.
    client = chromadb.PersistentClient(path="./chroma_db")  # Change path if needed

    # Depending on the ChromaDB release, list_collections() yields either
    # plain names (the v0.6.0 change) or Collection objects exposing `.name`.
    # Normalise to names so the membership test works with both.
    existing_names = {
        entry if isinstance(entry, str) else entry.name
        for entry in client.list_collections()
    }

    if collection_name in existing_names:
        collection = client.get_collection(name=collection_name)
        print(f"Using existing collection: {collection_name}")
    else:
        collection = client.create_collection(name=collection_name)
        print(f"Created new collection: {collection_name}")
    return collection
def initialize_collection(collection_name="clip_image_embeddings"):
    """Return the on-disk ChromaDB collection ``collection_name``, creating it when missing.

    A message is printed saying whether an existing collection was reused or
    a new one was created.

    NOTE(review): an earlier definition with the same name appears above in
    this module; this one overrides it.

    Args:
        collection_name: Name of the collection to fetch or create.

    Returns:
        The ChromaDB collection object for ``collection_name``.
    """
    client = chromadb.PersistentClient(path="./chroma_db")

    # list_collections() returns Collection objects in some ChromaDB releases
    # and bare name strings in others (0.6.x); accept both so `.name` is only
    # accessed when the entry is not already a name.
    known_names = set()
    for entry in client.list_collections():
        known_names.add(entry if isinstance(entry, str) else entry.name)

    if collection_name in known_names:
        collection = client.get_collection(name=collection_name)
        print(f"Using existing collection: {collection_name}")
    else:
        collection = client.create_collection(name=collection_name)
        print(f"Created new collection: {collection_name}")
    return collection
def main_create_image_collection(df_emb, collection_name="clip_image_embeddings"):
    """Index the image embeddings of ``df_emb`` into a ChromaDB collection.

    Every row is stored under its positional index (as a string ID), with the
    ``fashion_clip_image`` value as the embedding and the ``image_url`` value
    as the document. Rows without an embedding, rows whose URL is not a
    string, and IDs already present in the collection are skipped.

    Args:
        df_emb: DataFrame providing the "fashion_clip_image" and "image_url"
            columns. Each embedding cell is expected to hold one flat vector
            (or None) -- TODO confirm against other callers.
        collection_name: Name of the collection to fill; it is opened or
            created through ``initialize_collection``.

    Returns:
        The ChromaDB collection after the missing entries have been added.
    """
    # Entries per ChromaDB call. Kept small so a batch stays below the
    # per-call limit ChromaDB derives from the local SQLite build.
    batch_size = 100

    embeddings_all = df_emb["fashion_clip_image"].tolist()
    image_urls = df_emb["image_url"].tolist()

    collection = initialize_collection(collection_name)

    # IDs are the row positions, so skipped rows still consume their index
    # and the remaining IDs stay stable.
    pending = [
        (str(position), embedding, image_url)
        for position, (embedding, image_url) in enumerate(
            zip(embeddings_all, image_urls)
        )
        # ChromaDB documents must be strings; a missing URL cannot be stored.
        if embedding is not None and isinstance(image_url, str)
    ]

    for start in range(0, len(pending), batch_size):
        batch = pending[start:start + batch_size]

        # One lookup per batch instead of one per row.
        stored = collection.get(ids=[entry_id for entry_id, _, _ in batch])
        already_stored = set(stored["ids"]) if stored else set()

        fresh = [entry for entry in batch if entry[0] not in already_stored]
        if not fresh:
            continue

        collection.add(
            ids=[entry_id for entry_id, _, _ in fresh],
            # Hand array-like vectors over as plain lists of numbers.
            embeddings=[
                vector.tolist() if hasattr(vector, "tolist") else vector
                for _, vector, _ in fresh
            ],
            documents=[url for _, _, url in fresh],
        )
    return collection
# Checkpoint shared by the CLIP model and its processor.
model_name = "patrickjohncyh/fashion-clip"

# Pick the fastest available backend: CUDA, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Loaded once at import time and reused by every query in this module.
model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)
def main_text_retrieve_images(text, result_query=None, n_retrieved=3):
    """Retrieve the dataset images whose embedding is closest to ``text``.

    The embeddings dataset is loaded, reduced to one row per image URL and
    indexed in ChromaDB; ``text`` is encoded with the module-level CLIP model
    and used as the query vector.

    Args:
        text: Query string to encode.
        result_query: Optional data used to build a one-column "image_url"
            frame. When truthy, the search is restricted to those URLs using
            a temporary collection that is deleted after the query.
        n_retrieved: Number of results requested from ChromaDB.

    Returns:
        A list of dicts, one per retrieved image, in the order returned by
        the ChromaDB query. The embedding, description, editor, publish date
        and raw URL-list columns are removed. An empty list is returned when
        ``result_query`` matches no dataset row.
    """
    # load_dataset returns a DatasetDict keyed by split name; with no split
    # requested, take the first one.
    dataset = load_dataset("traopia/fashion_show_data_all_embeddings.json")
    split_name = next(iter(dataset.keys()))
    df_emb = dataset[split_name].to_pandas()

    df_emb = df_emb.drop_duplicates(subset="image_urls")
    # NOTE(review): only cells that are Python lists yield an embedding; any
    # other cell type becomes None and is skipped at indexing -- confirm the
    # cell type produced by to_pandas() for this dataset.
    df_emb["fashion_clip_image"] = df_emb["fashion_clip_image"].apply(
        lambda x: x[0] if isinstance(x, list) else None
    )
    df_emb["image_url"] = df_emb["image_urls"].apply(lambda x: x[0] if x else None)
    df_emb = df_emb.drop_duplicates(subset="image_url")

    temporary_collection_name = None
    if result_query:
        # De-duplicate the requested URLs so the join cannot multiply rows.
        df_small = pd.DataFrame(result_query, columns=["image_url"]).drop_duplicates()
        df_emb = df_emb.merge(df_small[["image_url"]], on="image_url", how="inner")
        if df_emb.empty:
            # Nothing to search in: skip creating an empty collection.
            return []
        # Microsecond resolution keeps two requests in the same second from
        # sharing a collection whose positional IDs would not match.
        temporary_collection_name = (
            f"collection_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
        )
        collection = main_create_image_collection(
            df_emb, collection_name=temporary_collection_name
        )
    else:
        collection = main_create_image_collection(df_emb)

    try:
        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,
        ).to(device)
        with torch.no_grad():
            text_features = model.get_text_features(**inputs).cpu().numpy()

        results = collection.query(
            # Explicit list-of-vectors form, as plain Python floats.
            query_embeddings=[text_features[0].tolist()],
            n_results=n_retrieved,
        )
    finally:
        if temporary_collection_name is not None:
            # The per-request collection is never reused; remove it from the
            # same on-disk store that initialize_collection opens so the
            # database does not grow with every filtered query.
            chromadb.PersistentClient(path="./chroma_db").delete_collection(
                name=temporary_collection_name
            )

    # Join from the ranked result so rows keep the order ChromaDB returned.
    result_doc = pd.DataFrame(results["documents"][0], columns=["image_url"])
    df_result = result_doc.merge(df_emb, on="image_url", how="inner")

    # Strip the embedding and bulky/unused metadata from the output; columns
    # that are absent from the dataset are ignored.
    df_result = df_result.drop(
        columns=[
            "fashion_clip_image",
            "description",
            "editor",
            "publish_date",
            "image_urls",
        ],
        errors="ignore",
    )
    return df_result.to_dict(orient="records")
if __name__ == "__main__":
    # Smoke test: run one sample text query and show the retrieved records.
    print(main_text_retrieve_images("dress"))