File size: 4,088 Bytes
6523110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113

import json
import pandas as pd
from transformers import CLIPProcessor, CLIPModel
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"


import chromadb

from datetime import datetime

def initialize_collection(collection_name="clip_image_embeddings"):
    """Open the on-disk Chroma store and return the named collection.

    The collection is reused when it already exists and created otherwise.

    Args:
        collection_name: Name of the Chroma collection to open or create.

    Returns:
        The Chroma collection object registered under ``collection_name``.
    """
    # PersistentClient stores embeddings on disk between runs.
    client = chromadb.PersistentClient(path="./chroma_db")  # Change path if needed

    # list_collections() yields plain name strings in Chroma 0.6.x but
    # Collection objects in other releases. Normalise both shapes to names so
    # the membership test works whichever version is installed; comparing a
    # string against Collection objects would never match and the subsequent
    # create_collection() would fail on an already existing name.
    existing_names = {
        entry if isinstance(entry, str) else getattr(entry, "name", None)
        for entry in client.list_collections()
    }

    if collection_name in existing_names:
        collection = client.get_collection(name=collection_name)
        print(f"Using existing collection: {collection_name}")
    else:
        collection = client.create_collection(name=collection_name)

    return collection

def main_create_image_collection(df_emb, collection_name="clip_image_embeddings"):
    """Store the image embeddings of ``df_emb`` in a Chroma collection.

    Every row is stored under its positional index (as a string id) with the
    row's ``image_url`` as the document. Rows whose embedding is ``None`` and
    ids that are already present in the collection are skipped.

    Args:
        df_emb: DataFrame providing the ``fashion_clip_image`` column (one
            embedding per row, or ``None``) and the ``image_url`` column.
        collection_name: Name of the collection to open or create.

    Returns:
        The Chroma collection after the missing entries have been added.
    """
    embeddings_all = df_emb["fashion_clip_image"].tolist()
    # The URL doubles as the stored document; a missing URL stays None.
    documents = df_emb["image_url"].tolist()

    collection = initialize_collection(collection_name)

    for position, (image_url, embedding) in enumerate(zip(documents, embeddings_all)):
        if embedding is None:
            continue  # Nothing to index for this row.

        # NOTE(review): ids are row positions, so a reused collection keeps
        # whatever was first stored under a position even if the DataFrame
        # content has changed since - confirm this is intended.
        embedding_id = str(position)

        # Only add ids the collection does not contain yet.
        existing_entry = collection.get(ids=[embedding_id])
        if existing_entry and existing_entry["ids"]:
            continue

        collection.add(
            ids=[embedding_id],
            embeddings=embedding,
            documents=[image_url],
        )

    return collection



# Hugging Face checkpoint loaded for both the CLIP model and its processor.
model_name = "patrickjohncyh/fashion-clip"

# Prefer Apple's MPS backend (the previously hard-coded choice), then CUDA,
# and fall back to CPU so that importing this module does not fail on
# machines without MPS support.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)

def main_text_retrieve_images(text, result_query=None, n_retrieved=3):
    """Return the metadata records of the images that best match ``text``.

    The embeddings file is loaded, reduced to one row per first image URL,
    indexed in a Chroma collection and queried with the CLIP text embedding
    of ``text``.

    Args:
        text: Query string to embed with the CLIP text encoder.
        result_query: Optional data handed to
            ``pd.DataFrame(result_query, columns=["image_url"])``; when given,
            the search is restricted to rows whose ``image_url`` appears in it
            and a fresh, timestamp-named collection is built for that subset.
        n_retrieved: Number of nearest neighbours requested from Chroma.

    Returns:
        A list of dicts (one per retrieved image, in the order Chroma returned
        the hits) without the ``fashion_clip_image``, ``description``,
        ``editor``, ``publish_date`` and ``image_urls`` columns.
    """

    def _first_or_none(value):
        # First element of a non-empty list; None for anything else
        # (missing cells, empty lists).
        return value[0] if isinstance(value, list) and value else None

    df_emb = pd.read_json("data/fashion_show_data_all_embeddings.json", lines=True)
    df_emb['fashion_clip_image'] = df_emb['fashion_clip_image'].apply(_first_or_none)
    df_emb['image_url'] = df_emb['image_urls'].apply(_first_or_none)
    # Deduplicating on the scalar first URL also covers rows whose whole
    # `image_urls` value repeats, and avoids hashing list-valued cells.
    df_emb = df_emb.drop_duplicates(subset='image_url')

    if result_query:
        df_small = pd.DataFrame(result_query, columns=["image_url"]).drop_duplicates()
        df_emb = df_emb.merge(df_small[['image_url']], on='image_url', how='inner')
        if df_emb.empty:
            # Nothing to search; do not create an empty collection.
            return []
        # Microseconds keep two calls within the same second from sharing a
        # collection: ids are positional, so a shared collection would
        # silently keep the first call's vectors.
        # NOTE(review): these per-call collections are never deleted here.
        collection_name = f"collection_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
        collection = main_create_image_collection(df_emb, collection_name=collection_name)
    else:
        collection = main_create_image_collection(df_emb)

    inputs = processor(
        text=[text], return_tensors="pt", padding=True, truncation=True, max_length=77
    ).to(device)

    with torch.no_grad():
        text_features = model.get_text_features(**inputs).cpu().numpy()

    # Pass an explicit one-element batch of plain floats.
    results = collection.query(
        query_embeddings=[text_features[0].tolist()],
        n_results=n_retrieved,
    )

    result_doc = pd.DataFrame(results['documents'][0], columns=["image_url"])
    # An inner merge preserves the order of the left keys, so the hits go on
    # the left to keep the order in which Chroma returned them.
    df_result = result_doc.merge(df_emb, on='image_url', how='inner')
    # Restore the column order of the source frame.
    df_result = df_result[list(df_emb.columns)]
    # Drop bulky / internal columns; tolerate files lacking optional ones.
    df_result = df_result.drop(
        columns=['fashion_clip_image', 'description', 'editor', 'publish_date', 'image_urls'],
        errors='ignore',
    )
    return df_result.to_dict(orient='records')

if __name__ == "__main__":
    # Manual smoke run: retrieve matches for a sample query and print them.
    sample_query = "dress"
    retrieved = main_text_retrieve_images(sample_query)
    print(retrieved)