|
|
|
import os |
|
import pickle |
|
import numpy as np |
|
from typing import Dict, Any, List |
|
from dotenv import load_dotenv |
|
from tqdm import tqdm |
|
from itertools import combinations |
|
|
|
from scipy import sparse |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
from mcp.server.fastmcp import FastMCP |
|
from mcp.server.sse import SseServerTransport |
|
from starlette.applications import Starlette |
|
from starlette.routing import Route, Mount |
|
import uvicorn |
|
import pandas as pd |
|
import torch |
|
from transformers import CLIPProcessor, CLIPModel |
|
from openai import AsyncOpenAI |
|
|
|
|
|
# Merge settings from a local .env file into the process environment before
# any of them are read below.
load_dotenv()

# Root directory of the FashionRec dataset; the fallback is a developer-machine
# path that is used only when FASHION_DATA_ROOT is not set.
FASHION_DATA_ROOT = os.environ.get(
    "FASHION_DATA_ROOT",
    "/mnt/d/PostDoc/fifth paper/code/FashionVLM/datasets/FashionRec",
)

# Credentials / endpoint for the OpenAI-compatible API (None when unset).
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
OPENAI_API_BASE = os.environ.get("OPENAI_API_BASE")

# Shared async client used by the MCP tools defined in this module.
openai = AsyncOpenAI(base_url=OPENAI_API_BASE, api_key=OPENAI_API_KEY)
|
|
|
|
|
|
|
|
|
|
|
# CLIP checkpoint shared by the model and its processor. local_files_only=True
# is passed to both loaders, so the checkpoint must already be on disk.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

clip_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT, local_files_only=True)
clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT, local_files_only=True)

# The model is only used for inference in this module.
clip_model.eval()
|
|
|
|
|
|
|
# Metadata tables, each re-indexed by its primary id column.
_meta_dir = f"{FASHION_DATA_ROOT}/meta"

items_df = pd.read_parquet(f"{_meta_dir}/items_lite.parquet").set_index("item_id")
outfits_df = pd.read_parquet(f"{_meta_dir}/outfits_lite.parquet").set_index("outfit_id")
users_df = pd.read_parquet(f"{_meta_dir}/users_lite.parquet").set_index("user_id")

# item_id -> value of the "path" column, used when returning retrieval results.
image_paths = items_df["path"].to_dict()
|
|
|
|
|
class InteractionDataManager:
    """Holds the user/outfit/item tables and the interaction data derived from them.

    On construction it builds id <-> matrix-index lookups, loads (or builds and
    caches) the item-item and user-item interaction matrices, and loads the
    pickled CLIP features used for similarity ranking and image retrieval.
    """

    # Maps the public ``matrix_type`` names to the attribute holding the matrix.
    _MATRIX_ATTRS = {
        'item': 'item_interaction_matrix',
        'user_item': 'user_item_interaction_matrix',
    }

    def __init__(self, users_df, outfits_df, items_df):
        """Build lookup tables, interaction matrices and CLIP feature arrays.

        Args:
            users_df: User table indexed by user id; each row carries a
                comma-separated ``outfit_ids`` string.
            outfits_df: Outfit table indexed by outfit id; each row carries a
                comma-separated ``item_ids`` string.
            items_df: Item table indexed by item id, with ``category`` and
                ``subcategory`` columns.
        """
        self.users_df = users_df
        self.outfits_df = outfits_df
        self.items_df = items_df

        # Bidirectional maps between ids and their row position, which is also
        # the row/column position inside the interaction matrices.
        self.item_id_to_index = {item_id: index for index, item_id in enumerate(self.items_df.index)}
        self.index_to_item_id = {index: item_id for index, item_id in enumerate(self.items_df.index)}
        self.user_id_to_index = {user_id: index for index, user_id in enumerate(self.users_df.index)}
        self.index_to_user_id = {index: user_id for index, user_id in enumerate(self.users_df.index)}
        self.outfit_ids_dict = self.outfits_df['item_ids'].to_dict()
        self.item_category_dict = self.items_df['category'].to_dict()
        self.item_subcategory_dict = self.items_df['subcategory'].to_dict()
        self.n_items = len(self.items_df)
        self.n_users = len(self.users_df)

        # (user_id, outfit_id) pairs, restricted to outfits present in outfits_df.
        self.user_outfit_pairs = []
        outfit_set = set(self.outfits_df.index)
        for uid, user in self.users_df.iterrows():
            oids = user.outfit_ids.split(",")
            self.user_outfit_pairs.extend((uid, oid) for oid in oids if oid in outfit_set)

        # subcategory -> set of item ids
        self.subcategory_to_items = self.items_df.groupby('subcategory').apply(lambda x: set(x.index)).to_dict()

        # subcategory -> set of matrix column indices
        self.subcategory_to_indices = {}
        for subcategory, member_ids in self.subcategory_to_items.items():
            self.subcategory_to_indices[subcategory] = {
                self.item_id_to_index[item_id]
                for item_id in member_ids
                if item_id in self.item_id_to_index
            }

        matrix_dir = f'{FASHION_DATA_ROOT}/data/personalized_recommendation/temp_matrix'
        self._load_or_build_matrix(
            'item', f'{matrix_dir}/item_matrix.npz', self.build_item_interaction_matrix
        )
        self._load_or_build_matrix(
            'user_item', f'{matrix_dir}/user_item_matrix.npz', self.build_user_item_interaction_matrix
        )

        self._load_clip_features(f"{FASHION_DATA_ROOT}/meta/clip_features.pkl")

    def _load_or_build_matrix(self, matrix_type, filepath, builder):
        """Load a cached matrix from ``filepath``; build and cache it when absent."""
        try:
            self.load_matrix(matrix_type, filepath)
        except FileNotFoundError:
            builder()
            self.save_matrix(matrix_type, filepath)

    def _load_clip_features(self, filepath):
        """Load the pickled CLIP feature dict and derive the retrieval arrays.

        Sets ``clip_features`` (item id -> feature dict), ``item_ids`` (the
        dict's keys in order) and ``image_embeddings`` (one row per entry of
        ``item_ids``, in the same order).
        """
        # NOTE(security): pickle.load can execute arbitrary code; this file
        # must come from a trusted source.
        with open(filepath, "rb") as f:
            print("Loading Fashion Features...")
            self.clip_features = pickle.load(f)
            print("Loading Fashion Features Successfully")

        self.item_ids = list(self.clip_features.keys())
        # Iterate self.item_ids explicitly so row i of image_embeddings always
        # belongs to self.item_ids[i]; callers index item_ids by argmax row.
        self.image_embeddings = np.array(
            [self.clip_features[item_id]["image_embeds"] for item_id in self.item_ids]
        )

    def save_matrix(self, matrix_type, filepath):
        """Save one of the interaction matrices to an ``.npz`` file.

        Args:
            matrix_type: 'item' or 'user_item', selecting which matrix to save.
            filepath: Destination path (e.g. 'temp/item_matrix.npz'). Missing
                parent directories are created.

        Raises:
            ValueError: If ``matrix_type`` is unknown or the selected matrix
                has not been built or loaded yet.
        """
        attr = self._MATRIX_ATTRS.get(matrix_type)
        if attr is None:
            raise ValueError("matrix_type must be 'item' or 'user_item'")

        # getattr with a default so an unbuilt matrix raises the intended
        # ValueError instead of an AttributeError.
        matrix = getattr(self, attr, None)
        if matrix is None:
            raise ValueError(f"{matrix_type} matrix has not been built yet.")

        # The cache directory may not exist on a fresh checkout; create it so a
        # freshly built matrix is not lost to a FileNotFoundError.
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)

        sparse.save_npz(filepath, matrix)
        print(f"Saved {matrix_type} matrix to {filepath}")

    def load_matrix(self, matrix_type, filepath):
        """Load one of the interaction matrices from an ``.npz`` file.

        Args:
            matrix_type: 'item' or 'user_item', selecting which attribute
                receives the loaded matrix.
            filepath: Source path (e.g. 'temp/item_matrix.npz').

        Returns:
            The loaded sparse matrix.

        Raises:
            FileNotFoundError: If ``filepath`` does not exist.
            ValueError: If ``matrix_type`` is unknown.
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"File {filepath} does not exist.")

        matrix = sparse.load_npz(filepath)
        attr = self._MATRIX_ATTRS.get(matrix_type)
        if attr is None:
            raise ValueError("matrix_type must be 'item' or 'user_item'")
        setattr(self, attr, matrix)

        print(f"Loaded {matrix_type} matrix from {filepath}")
        return matrix

    def build_item_interaction_matrix(self):
        """Build the symmetric item-item co-occurrence matrix from the outfits.

        Every pair of known items appearing in the same outfit adds 1 to both
        ``[i, j]`` and ``[j, i]``. The result is stored on the instance in CSR
        format and returned.
        """
        # LIL is used while filling (incremental updates), then converted to CSR.
        self.item_interaction_matrix = sparse.lil_matrix((self.n_items, self.n_items), dtype=int)

        for _, outfit in tqdm(self.outfits_df.iterrows(), total=len(self.outfits_df)):
            item_ids = outfit['item_ids'].split(',')
            for item_id1, item_id2 in combinations(item_ids, r=2):
                if item_id1 in self.item_id_to_index and item_id2 in self.item_id_to_index:
                    idx1 = self.item_id_to_index[item_id1]
                    idx2 = self.item_id_to_index[item_id2]
                    self.item_interaction_matrix[idx1, idx2] += 1
                    self.item_interaction_matrix[idx2, idx1] += 1

        self.item_interaction_matrix = self.item_interaction_matrix.tocsr()
        return self.item_interaction_matrix

    def build_user_item_interaction_matrix(self):
        """Build the user-item interaction matrix from each user's outfits.

        Entry ``[u, i]`` counts how many of user ``u``'s outfits contain item
        ``i``. The result is stored on the instance in CSR format and returned.
        """
        self.user_item_interaction_matrix = sparse.lil_matrix((self.n_users, self.n_items), dtype=int)

        for uid, user in tqdm(self.users_df.iterrows(), total=len(self.users_df)):
            uidx = self.user_id_to_index[uid]
            oids = user["outfit_ids"].split(",")
            outfits = self.outfits_df.loc[self.outfits_df.index.isin(oids)]
            for _, outfit in outfits.iterrows():
                for iid in outfit['item_ids'].split(','):
                    if iid in self.item_id_to_index:
                        iidx = self.item_id_to_index[iid]
                        self.user_item_interaction_matrix[uidx, iidx] += 1

        self.user_item_interaction_matrix = self.user_item_interaction_matrix.tocsr()
        return self.user_item_interaction_matrix

    def _process_interactions_for_category(
        self,
        matrix,
        given_id,
        category_indices,
        id_to_index
    ):
        """Collect one entity's stored interactions with items of a category.

        Args:
            matrix: Interaction matrix whose rows are indexed via ``id_to_index``
                and whose columns are item indices.
            given_id: Id of the querying entity (user or item).
            category_indices: Set of item column indices in the target category.
            id_to_index: Mapping from entity id to its row in ``matrix``.

        Returns:
            A list of dicts with ``item_id``, ``interaction_count`` and a
            placeholder ``score`` of 0.0.

        Raises:
            KeyError: If ``given_id`` is not present in ``id_to_index``.
        """
        given_index = id_to_index[given_id]
        row = matrix[given_index]

        # Read the stored entries of the single-row CSR slice directly.
        row_start = row.indptr[0]
        row_end = row.indptr[1]
        col_indices = row.indices[row_start:row_end]
        data_values = row.data[row_start:row_end]

        interactions = []
        for col_idx, value in zip(col_indices, data_values):
            col_idx = int(col_idx)
            if col_idx in category_indices:
                interactions.append({
                    'item_id': self.index_to_item_id[col_idx],
                    'interaction_count': int(value),
                    'score': 0.0
                })
        return interactions

    def get_item_category_interactions(
        self,
        target_category: str,
        given_ids: List[str],
        query_type='item',
        top_k=None,
    ):
        """Aggregate interactions between the given entities and one subcategory.

        Args:
            target_category: Subcategory to query.
            given_ids: Ids of the querying entities (item ids when
                ``query_type`` is 'item', user ids when it is 'user').
            query_type: 'item' or 'user'; any other value prints a message and
                yields an empty list.
            top_k: Keep only the ``top_k`` most-interacted items; ``None``
                returns all of them.

        Returns:
            A list of dicts (``item_id``, ``interaction_count``, ``score``)
            with counts summed over ``given_ids`` and sorted by count,
            descending.
        """
        if query_type == 'item':
            matrix = self.item_interaction_matrix
            id_to_index = self.item_id_to_index
        elif query_type == 'user':
            matrix = self.user_item_interaction_matrix
            id_to_index = self.user_id_to_index
        else:
            print(f'query_type must be either item or user but got {query_type}')
            return []

        category_indices = self.subcategory_to_indices.get(target_category, set())

        # Sum the counts per item across all querying entities.
        totals = {}
        for given_id in given_ids:
            for interaction in self._process_interactions_for_category(
                matrix, given_id, category_indices, id_to_index
            ):
                item_id = interaction['item_id']
                totals[item_id] = totals.get(item_id, 0) + interaction['interaction_count']

        merged_interactions = [
            {'item_id': item_id, 'interaction_count': count, 'score': 0.0}
            for item_id, count in totals.items()
        ]
        merged_interactions.sort(key=lambda x: x['interaction_count'], reverse=True)

        if top_k:
            merged_interactions = merged_interactions[:top_k]
        return merged_interactions

    def rank_by_similarity(self, item_interactions, user_interactions, beta=2.0):
        """Score the user's items by similarity to the query items and sort them.

        A weighted mean (weights = interaction counts) of the combined CLIP
        features of ``item_interactions`` is used as the reference. Each entry
        of ``user_interactions`` gets
        ``score = cosine_similarity * (count / max_count * beta + 1)``.

        Args:
            item_interactions: Interaction dicts for the query items.
            user_interactions: Interaction dicts for the user's items; updated
                and sorted in place.
            beta: Strength of the interaction-count boost.

        Returns:
            ``user_interactions`` sorted by score, descending. When
            ``item_interactions`` is empty there is no reference to compare
            against and the list is returned unchanged.

        Raises:
            ValueError: If any referenced item has no CLIP feature.
        """
        if not item_interactions:
            return user_interactions

        def get_combined_features(feature_dict):
            # Average of the image and text embeddings.
            return (feature_dict['image_embeds'] + feature_dict['text_embeds']) / 2

        item_feature_list = []
        for item in item_interactions:
            item_id = item['item_id']
            if item_id not in self.clip_features:
                raise ValueError(f"Didn't find clip feature of item with id: {item_id}")
            item_feature_list.append(get_combined_features(self.clip_features[item_id]))

        weights = np.array([x['interaction_count'] for x in item_interactions], dtype=np.float32)
        weights = weights / np.sum(weights)
        item_feature = np.sum(
            np.stack(item_feature_list, axis=0) * weights[:, np.newaxis], axis=0
        ).reshape(1, -1)

        max_count = max(
            (user_item.get('interaction_count', 1) for user_item in user_interactions),
            default=1,
        )
        for user_item in user_interactions:
            user_item_id = user_item['item_id']
            if user_item_id not in self.clip_features:
                raise ValueError(f"Didn't find clip feature of item with id: {user_item_id}")

            user_item_features = get_combined_features(self.clip_features[user_item_id]).reshape(1, -1)
            similarity = cosine_similarity(user_item_features, item_feature).item()
            interaction_count = user_item['interaction_count']
            count_factor = (interaction_count / max_count) * beta + 1
            user_item['score'] = float(similarity) * count_factor

        user_interactions.sort(key=lambda x: x.get('score', 0), reverse=True)
        return user_interactions
|
|
|
|
|
# Module-wide singletons: the interaction/feature index built from the tables
# loaded above, and the FastMCP server the tools below register with.
data_manager = InteractionDataManager(
    users_df,
    outfits_df,
    items_df,
)
mcp = FastMCP("image-retrieval-server")
|
|
|
|
|
@mcp.tool()
async def summary_user_history(user_id: str, target_category: str, list_of_items: List[str]) -> str:
    """Summary user's buying history of specific fashion category given user_id, target_category, list_of_items
    After we collect all buying history of this user, we will summarize descriptions of these historical items through LLM.
    So we will return user's preference about target_category in sentences.

    Args:
        user_id (str): User id. Will be provided through prompt
        target_category (str): We care about user's buying history of this specific category.
        list_of_items: List of item ids for history filtering. Will be provided through prompt
    """
    # Target-category items that interact with the given items.
    item_interaction_result = data_manager.get_item_category_interactions(
        target_category, list_of_items, query_type='item'
    )
    # Target-category items that the user has interacted with.
    user_interaction_result = data_manager.get_item_category_interactions(
        target_category, [user_id], query_type='user'
    )

    def get_description(item_id: str) -> str:
        return data_manager.items_df.loc[item_id].gen_description

    if item_interaction_result and user_interaction_result:
        # Rank the user's history by similarity to the given items and keep
        # the five best matches.
        selected = data_manager.rank_by_similarity(
            item_interaction_result,
            user_interaction_result
        )[:5]
    elif not item_interaction_result:
        # Nothing to rank against: fall back to the user's full history for
        # this category (already sorted by interaction count).
        selected = user_interaction_result
    else:
        selected = []

    descriptions_for_summary = [get_description(x['item_id']) for x in selected]
    if not descriptions_for_summary:
        return ""

    # NOTE(review): the prompt strings below contain typos ("brought",
    # "staring at"); they are left untouched because changing them alters the
    # text sent to the LLM -- confirm before fixing.
    user_message = (
        f"Summary user's preference of {target_category} based on following descriptions of fashion items that user brought previously:"
        + "".join(f"\n{x}" for x in descriptions_for_summary)
    )

    response = await openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are a user preference summary assistant. Your response is limited in one sentence, staring at 'I prefer ...'"},
            {"role": "user", "content": user_message}
        ],
        max_tokens=1000,
    )
    # message.content can be None; keep the declared ``str`` return type.
    return response.choices[0].message.content or ""
|
|
|
|
|
# Sample inputs for manually exercising ``summary_user_history``.
user_id = "115"

partial_outfit = ["25479e5dacebbfaed18a7dc4830bd5cd19114486", "becc7b46236e9abb6f6760e7a1569b06bbc236c1",
                  "180c32b5c8c164f3c632f3e73d6002ccfa6fea57"]

target_category = "Skirts"


def _run_summary_smoke_test() -> str:
    """Run ``summary_user_history`` once on the sample inputs above.

    The original module-level call invoked the async function without awaiting
    it, which only created a never-awaited coroutine at import time. This
    helper actually runs it, and only when called explicitly (it performs an
    OpenAI request, so it must not run on import).
    """
    import asyncio

    return asyncio.run(summary_user_history(user_id, target_category, partial_outfit))
|
|
|
|
|
async def compute_text_embedding(text: str) -> np.ndarray:
    """Encode ``text`` with the CLIP text tower and L2-normalise each row."""
    encoded = clip_processor(text=text, return_tensors="pt", padding=True, truncation=True)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        raw_features = clip_model.get_text_features(**encoded).numpy()

    row_norms = np.linalg.norm(raw_features, axis=1, keepdims=True)
    return raw_features / row_norms
|
|
|
|
|
async def find_most_similar_image(text_embedding: np.ndarray) -> Dict[str, Any]:
    """Return the image path and score of the item that best matches the text.

    The score is the dot product between each stored image embedding and
    ``text_embedding``; the item with the highest score wins.
    """
    scores = np.dot(data_manager.image_embeddings, text_embedding.T).flatten()
    best_idx = np.argmax(scores)
    best_item_id = data_manager.item_ids[best_idx]

    result = {
        "image_path": image_paths[best_item_id],
        "similarity": float(scores[best_idx]),
    }
    return result
|
|
|
|
|
@mcp.tool()
async def retrieve_image(text: str) -> Dict[str, Any]:
    """Search for the most similar fashion image based on a text description.

    Args:
        text (str): Text description of the fashion item to search.
    """
    print(f"Searching for {text}")
    # Embed the query, then look up the closest stored image embedding.
    query_embedding = await compute_text_embedding(text)
    best_match = await find_most_similar_image(query_embedding)
    return best_match
|
|
|
|
|
# SSE transport whose client messages are posted to "/messages/", and the
# server object held by FastMCP (reached through its private attribute).
sse_transport = SseServerTransport("/messages/")
mcp_server = mcp._mcp_server
|
|
|
|
|
async def handle_sse(request):
    """Open an SSE connection for ``request`` and run the MCP server over it."""
    print("Handling SSE connection")
    connection = sse_transport.connect_sse(request.scope, request.receive, request._send)
    async with connection as (read_stream, write_stream):
        init_options = mcp_server.create_initialization_options()
        await mcp_server.run(read_stream, write_stream, init_options)
|
|
|
|
|
# "/sse" opens the event stream; "/messages/" receives the client's posts.
routes = [
    Route(path="/sse", endpoint=handle_sse),
    Mount(path="/messages/", app=sse_transport.handle_post_message),
]

# ASGI application served by uvicorn in the __main__ guard.
starlette_app = Starlette(routes=routes)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the Starlette app on all interfaces, port 8001.
    print("Starting Image Retrieval server with HTTP and SSE...")
    uvicorn.run(starlette_app, port=8001, host="0.0.0.0")
|
|