|
import json |
|
import os |
|
import pandas as pd |
|
import uuid |
|
|
|
from openai import AsyncOpenAI |
|
from dotenv import load_dotenv |
|
import chainlit as cl |
|
|
|
from system_message import SYSTEM_MESSAGE |
|
from mcp_client import MCPClient |
|
from utils import create_image_grid |
|
import httpx |
|
|
|
|
|
|
|
load_dotenv()

# Runtime configuration, read from the environment (optionally via a .env file).
CHAINLIT_PORT = os.getenv("CHAINLIT_PORT", "8888")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
FASHION_DATA_ROOT = os.getenv("FASHION_DATA_ROOT")
PROXY = os.getenv("PROXY")

# Item catalogue; on_message checks uploaded image file stems against item_id_set.
items_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/items_lite.parquet")
item_id_set = set(items_df.item_id)

# AsyncOpenAI only accepts an httpx.AsyncClient, so both branches must build an
# async client (the old fallback built a sync httpx.Client, which the SDK rejects).
# The proxy is applied only when PROXY is set.
http_client = httpx.AsyncClient(proxy=PROXY) if PROXY else httpx.AsyncClient()
|
|
|
|
|
class FashionAgent:
    """Hold the MCP tool client and the async OpenAI client used to answer chats."""

    def __init__(self, user_id=None):
        """Create the MCP client (from ``mcp_server_config.json``) and the OpenAI client.

        Args:
            user_id: Optional user identifier, forwarded to ``MCPClient`` and
                stored on the instance.
        """
        self.mcp_client = MCPClient("mcp_server_config.json", user_id)
        # Honour OPENAI_API_BASE so a custom/compatible endpoint can be used.
        # When it is None the SDK falls back to its own default base URL.
        self.openai = AsyncOpenAI(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_API_BASE,
            http_client=http_client,
        )
        self.user_id = user_id


# One agent instance, created at import time with no user id.
# NOTE(review): every chat session shares this agent and its MCP client, and
# on_chat_start calls connect_to_servers() on it for each new session. Confirm
# that repeated connects on a shared client are intended.
agent = FashionAgent(user_id=None)
|
|
|
|
|
@cl.on_chat_start
async def on_chat_start():
    """Connect the shared agent to its MCP servers and greet the new session.

    The agent is stored in the user session under "agent", which is where
    ``on_message`` looks it up.
    """
    session_agent = agent
    await session_agent.mcp_client.connect_to_servers()
    cl.user_session.set("agent", session_agent)
    welcome = "Hello Sophia! Welcome to FashionM3. How can I assist you today?"
    await cl.Message(content=welcome).send()
|
|
|
|
|
def _build_user_message(message, user_id):
    """Return the prompt text for *message*, annotated with upload and user context.

    Appends, in order:
    - the absolute path of the uploaded image (several uploads are first merged
      into one grid of at most four images);
    - the user id plus the ids of uploaded catalogue items, when any upload's
      file stem is in ``item_id_set``; otherwise just the user id, when it is
      truthy.
    """
    user_message = message.content
    upload_images = [element for element in message.elements if isinstance(element, cl.Image)]
    upload_paths = [image.path for image in upload_images]

    if len(upload_paths) == 1:
        user_message += f"\nThe uploaded image path is: {os.path.abspath(upload_paths[0])}"
    elif len(upload_paths) > 1:
        merged_image_path = f".files/{uuid.uuid4().hex}.jpg"
        create_image_grid(upload_paths[:4], merged_image_path)
        user_message += f"\nThe uploaded image path is: {os.path.abspath(merged_image_path)}"

    # An upload named "<item_id>.<ext>" is treated as an existing catalogue item.
    # NOTE(review): assumes item_id_set holds strings; confirm the parquet dtype.
    image_in_database = []
    for image in upload_images:
        item_id = image.name.split(".")[0]
        if item_id in item_id_set:
            image_in_database.append(item_id)

    if image_in_database:
        user_message += f"\nUser id is: {user_id}"
        user_message += f"\nlist_of_items are: {image_in_database}"
    elif user_id:
        user_message += f"\nUser id is: {user_id}"
    return user_message


async def _execute_tool_call(agent, tool_call):
    """Run one model-requested tool call, show the result in the chat, and return its text.

    Known tools get a tailored reply (images are attached inline). Any other
    tool's raw result is shown as text. Failures, including malformed JSON
    arguments, are reported to the user instead of aborting the turn.
    """
    tool_name = tool_call.function.name
    elements = []
    try:
        params = json.loads(tool_call.function.arguments)
        print(f"Agent execute {tool_name} with params: {params}")
        result = await agent.mcp_client.execute_tool(tool_name, params)
        # Exactly one branch runs; previously a trailing `else` bound only to
        # "try_on" and overwrote every other tool's output with the raw result.
        if tool_name == "retrieve_image":
            payload = json.loads(result['result'][0].text)
            similarity = payload['similarity']
            output = f"I found a matching fashion item with a similarity score of {similarity:.2f}"
            elements = [cl.Image(path=payload['image_path'], name="Product image", display="inline", size="medium")]
        elif tool_name == "image_generate":
            output = "Here is the generated image."
            elements = [cl.Image(path=result['result'][0].text, name="Product image", display="inline", size="medium")]
        elif tool_name == "fashion_recommend_without_image":
            output = result['result'][0].text
        elif tool_name == "fashion_recommend":
            output = json.loads(result['result'][0].text)['recommendation']
        elif tool_name == "try_on":
            output = "Here is the virtual try-on image."
            elements = [cl.Image(path=result['result'][0].text, name="Try-on image", display="inline", size="large")]
        else:
            # Stringify so chat history only ever holds text.
            output = str(result)
    except Exception as e:
        output = f"Error executing tool {tool_name}: {str(e)}"
        elements = []
    await cl.Message(content=output, elements=elements, author="Fashion Agent").send()
    return output


@cl.on_message
async def on_message(message: cl.Message):
    """Answer one user turn: prompt the model, run any tool calls, update history.

    Chat history lives in the Chainlit user session as ``cl.Message`` objects
    whose ``author`` is "user" or "assistant".
    """
    agent = cl.user_session.get("agent")
    user_id = cl.user_session.get("user_id")
    chat_history = cl.user_session.get("chat_history", [])

    user_message = _build_user_message(message, user_id)

    # Every history entry is a cl.Message, so the role must come from its
    # author; an isinstance check would label every past turn "user".
    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        *[
            {"role": "user" if msg.author == "user" else "assistant", "content": msg.content}
            for msg in chat_history
        ],
        {"role": "user", "content": user_message},
    ]

    available_tools = await agent.mcp_client.get_tools()

    response = await agent.openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        max_tokens=1000,
        tools=available_tools if available_tools else None,
        tool_choice="auto" if available_tools else None,
    )
    response_message = response.choices[0].message

    if response_message.tool_calls:
        # Record every tool's reply in history, not just the last one.
        outputs = []
        for tool_call in response_message.tool_calls:
            outputs.append(await _execute_tool_call(agent, tool_call))
        output = "\n".join(outputs)
    else:
        output = response_message.content or ""
        await cl.Message(content=output, author="Fashion Agent").send()

    chat_history.append(cl.Message(content=message.content, author="user"))
    chat_history.append(cl.Message(content=output, author="assistant"))
    cl.user_session.set("chat_history", chat_history)
|
|
|
|
|
@cl.on_chat_end
def on_chat_end():
    """Print a goodbye line with the id of the session that just ended."""
    session_id = cl.user_session.get("id")
    print("Goodbye", session_id)
|
|
|
|
|
if __name__ == "__main__":
    # The CLI runner is only needed when this file is executed directly.
    from chainlit.cli import run_chainlit as launch

    # Expose the configured port through the environment, then serve this file.
    os.environ["CHAINLIT_PORT"] = CHAINLIT_PORT
    launch(__file__)
|
|