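"""Text-to-image retrieval over FashionCLIP embeddings stored in ChromaDB.

Loads precomputed image embeddings from a Hugging Face dataset, indexes them
in a persistent ChromaDB collection, and retrieves the images closest to a
free-text query embedded with the same FashionCLIP model.
"""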

import json
import pandas as pd
from transformers import CLIPProcessor, CLIPModel
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from datasets import load_dataset
import chromadb

from datetime import datetime

def initialize_collection(collection_name="clip_image_embeddings"):
    """Return the named ChromaDB collection, creating it if necessary."""
    # Initialize ChromaDB client (PersistentClient stores embeddings between runs)
    client = chromadb.PersistentClient(path="./chroma_db")  # Change path if needed

    # list_collections() returns Collection objects before chromadb v0.6.0 and
    # plain name strings from v0.6.0 onwards; normalize to names so both work.
    existing_collections = [
        col.name if hasattr(col, "name") else col
        for col in client.list_collections()
    ]

    if collection_name in existing_collections:
        collection = client.get_collection(name=collection_name)
        print(f"Using existing collection: {collection_name}")
    else:
        collection = client.create_collection(name=collection_name)
        print(f"Created new collection: {collection_name}")

    return collection
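# Note: recent ChromaDB releases also expose get_or_create_collection(),
# which collapses the name lookup above into a single call:
#   collection = client.get_or_create_collection(name=collection_name)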

def main_create_image_collection(df_emb, collection_name="clip_image_embeddings"):
    """Index every (image_url, embedding) pair of df_emb in a ChromaDB collection."""
    # One document (the image URL) per row, aligned with its embedding.
    documents = df_emb["image_url"].tolist()
    embeddings_all = df_emb["fashion_clip_image"].tolist()

    collection = initialize_collection(collection_name)

    for i, d in enumerate(documents):
        embedding = embeddings_all[i]
        embedding_id = str(i)  # ChromaDB IDs must be strings

        if embedding is None or d is None:
            continue  # Skip rows without an embedding or a URL

        # Skip IDs that were already indexed in a previous run
        existing_entry = collection.get(ids=[embedding_id])
        if existing_entry and existing_entry["ids"]:
            continue

        collection.add(
            ids=[embedding_id],
            embeddings=[embedding],  # add() expects a list of embeddings
            documents=[d],
        )

    return collection
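# Example usage (hypothetical data, for illustration only):
#   df = pd.DataFrame({
#       "image_url": ["https://example.com/look1.jpg"],
#       "fashion_clip_image": [[0.12, -0.03, 0.41]],  # one FashionCLIP vector per row
#   })
#   main_create_image_collection(df, collection_name="demo_collection")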



# Load FashionCLIP once at module level so indexing and querying share the
# same model. "mps" targets Apple Silicon; use "cuda" on NVIDIA machines.
model_name = "patrickjohncyh/fashion-clip"
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)

def main_text_retrieve_images(text, result_query=None, n_retrieved=3):
    """Return the n_retrieved images whose embeddings best match the query text."""
    # Load the dataset. With no split specified, load_dataset returns a
    # DatasetDict keyed by split name (usually just 'train'); take the first.
    dataset = load_dataset("traopia/fashion_show_data_all_embeddings.json")
    split_name = list(dataset.keys())[0]
    full_dataset = dataset[split_name]

    # Convert to pandas and keep one embedding and one representative URL per
    # row. Deduplicate on the scalar image_url column (the raw image_urls
    # lists are unhashable, so they cannot be passed to drop_duplicates).
    df_emb = full_dataset.to_pandas()
    df_emb['fashion_clip_image'] = df_emb['fashion_clip_image'].apply(
        lambda x: x[0] if x is not None and len(x) > 0 else None)
    df_emb['image_url'] = df_emb['image_urls'].apply(
        lambda x: x[0] if x is not None and len(x) > 0 else None)
    df_emb = df_emb.drop_duplicates(subset='image_url')

    #print("DataFrame head:", df_emb.head())  # Debugging statement

    if result_query:
        # Restrict retrieval to the URLs in result_query, and index the
        # filtered subset under a fresh timestamped collection so it does not
        # collide with the full persistent index.
        df_small = pd.DataFrame(result_query, columns=["image_url"])
        df_emb = df_emb.merge(df_small[['image_url']], on='image_url', how='inner')
        collection_name = f"collection_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        collection = main_create_image_collection(df_emb, collection_name=collection_name)
    else:
        collection = main_create_image_collection(df_emb)

    # Embed the query text with the same FashionCLIP model used for the images.
    inputs = processor(text=[text], return_tensors="pt", padding=True,
                       truncation=True, max_length=77).to(device)

    with torch.no_grad():
        text_features = model.get_text_features(**inputs).cpu().numpy()

    results = collection.query(
        query_embeddings=text_features.tolist(),  # query() expects a list of embeddings
        n_results=n_retrieved
    )

    # Map the retrieved documents (image URLs) back onto the full metadata rows.
    result_doc = pd.DataFrame(results['documents'][0], columns=["image_url"])
    df_result = df_emb.merge(result_doc[['image_url']], on='image_url', how='inner')

    # Drop the embedding and other bulky columns before returning records.
    df_result = df_result.drop(
        columns=['fashion_clip_image', 'description', 'editor', 'publish_date', 'image_urls'],
        errors='ignore')
    return df_result.to_dict(orient='records')
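# Note: ChromaDB collections use L2 distance by default, while CLIP-style
# retrieval is usually done with cosine similarity. If desired, request it
# when the collection is created:
#   client.create_collection(name=..., metadata={"hnsw:space": "cosine"})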

if __name__ == "__main__":
    text = "dress"
    result = main_text_retrieve_images(text)
    print(result)