# FashionM3 / chainlit_app.py
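"""Chainlit front-end for FashionM3.

Routes user messages (text and uploaded images) to an OpenAI chat model,
exposes MCP tools to the model, and renders tool results (retrieved images,
generated images, recommendations, virtual try-on) as chat messages.
"""
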
import json
import os
import uuid

import chainlit as cl
import httpx
import pandas as pd
from dotenv import load_dotenv
from openai import AsyncOpenAI

from mcp_client import MCPClient
from system_message import SYSTEM_MESSAGE
from utils import create_image_grid

# Load environment variables from .env
load_dotenv()
# Configuration read from the environment
CHAINLIT_PORT = os.getenv("CHAINLIT_PORT", "8888")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
FASHION_DATA_ROOT = os.getenv("FASHION_DATA_ROOT")
PROXY = os.getenv("PROXY")
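# Item metadata; item_id_set is used to recognise uploaded catalogue images by filename.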
items_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/items_lite.parquet")
item_id_set = set(items_df.item_id)
# AsyncOpenAI needs an async HTTP client in both branches; route through the proxy if one is set.
http_client = httpx.AsyncClient(proxy=PROXY) if PROXY else httpx.AsyncClient()
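

# FashionAgent bundles the MCP tool client and the OpenAI chat client for one user.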
class FashionAgent:
    def __init__(self, user_id=None):
        self.mcp_client = MCPClient("mcp_server_config.json", user_id)
        # Pass the configured base URL through if one is set; None falls back to the SDK default.
        self.openai = AsyncOpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE, http_client=http_client)
        self.user_id = user_id


# Global FashionAgent instance shared across chat sessions.
agent = FashionAgent(user_id=None)


@cl.on_chat_start
async def on_chat_start():
    await agent.mcp_client.connect_to_servers()
    cl.user_session.set("agent", agent)
    await cl.Message(content="Hello Sophia! Welcome to FashionM3. How can I assist you today?").send()
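

# Per-message flow: fold uploaded images and known item ids into the prompt,
# call the model with the MCP tools exposed, then dispatch any tool calls.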
@cl.on_message
async def on_message(message: cl.Message):
    agent = cl.user_session.get("agent")
    user_id = cl.user_session.get("user_id")
    chat_history = cl.user_session.get("chat_history", [])
    user_message = message.content
    uploaded_images = [x.path for x in message.elements if isinstance(x, cl.Image)]
    if len(uploaded_images) == 1:
        user_message += f"\nThe uploaded image path is: {os.path.abspath(uploaded_images[0])}"
    elif len(uploaded_images) > 1:
        # Merge up to four uploads into one grid image so a single path can be passed to the tools.
        os.makedirs(".files", exist_ok=True)
        merged_image_path = f".files/{uuid.uuid4().hex}.jpg"
        create_image_grid(uploaded_images[:4], merged_image_path)
        user_message += f"\nThe uploaded image path is: {os.path.abspath(merged_image_path)}"
    image_in_database = []
    for image in message.elements:
        if isinstance(image, cl.Image):
            item_id = image.name.split(".")[0]
            if item_id in item_id_set:
                image_in_database.append(item_id)
    if len(image_in_database) > 0:
        user_message += f"\nUser id is: {user_id}"
        user_message += f"\nlist_of_items are: {image_in_database}"
    elif user_id:
        user_message += f"\nUser id is: {user_id}"
    # Prepare messages for the OpenAI API; chat_history holds cl.Message objects whose
    # author field records whether they came from the user or the assistant.
    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        *[{"role": "user" if msg.author == "user" else "assistant", "content": msg.content} for msg in chat_history],
        {"role": "user", "content": user_message}
    ]
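    # Tools advertised by the connected MCP servers are passed straight to the model;
    # tool_choice="auto" lets it decide between answering directly and calling a tool.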
    # Fetch available tools
    available_tools = await agent.mcp_client.get_tools()
    # Initial OpenAI API call
    response = await agent.openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        max_tokens=1000,
        tools=available_tools if available_tools else None,
        tool_choice="auto" if available_tools else None
    )
    # Process the response
    response_message = response.choices[0].message
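    # Each tool call is executed through the MCP client and rendered as its own chat
    # message: images for retrieval / generation / try-on, plain text for recommendations.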
    if response_message.tool_calls:
        # Handle tool calls
        for tool_call in response_message.tool_calls:
            tool_name = tool_call.function.name
            params = json.loads(tool_call.function.arguments)
            try:
                print(f"Agent executes {tool_name} with params: {params}")
                result = await agent.mcp_client.execute_tool(tool_name, params)
                if tool_name == "retrieve_image":
                    payload = json.loads(result['result'][0].text)
                    image_path = payload['image_path']
                    similarity = payload['similarity']
                    output = f"I found a matching fashion item with a similarity score of {similarity:.2f}"
                    images = [cl.Image(path=image_path, name="Product image", display="inline", size="medium")]
                    await cl.Message(content=output, elements=images, author="Fashion Agent").send()
                elif tool_name == "image_generate":
                    image_path = result['result'][0].text
                    images = [cl.Image(path=image_path, name="Product image", display="inline", size="medium")]
                    output = "Here is the generated image."
                    await cl.Message(content=output, elements=images, author="Fashion Agent").send()
                elif tool_name == "fashion_recommend_without_image":
                    output = result['result'][0].text
                    await cl.Message(content=output, author="Fashion Agent").send()
                elif tool_name == "fashion_recommend":
                    output = json.loads(result['result'][0].text)['recommendation']
                    # user_preference = json.loads(result['result'][0].text)['user_preference']
                    # await cl.Message(content=user_preference, author="Fashion Agent").send()
                    await cl.Message(content=output, author="Fashion Agent").send()
                elif tool_name == "try_on":
                    image_path = result['result'][0].text
                    images = [cl.Image(path=image_path, name="Try-on image", display="inline", size="large")]
                    output = "Here is the virtual try-on image."
                    await cl.Message(content=output, elements=images, author="Fashion Agent").send()
                else:
                    # Fall back to the raw tool result for tools without a dedicated renderer.
                    output = result
            except Exception as e:
                output = f"Error executing tool {tool_name}: {str(e)}"
        # Update chat history
        chat_history.append(cl.Message(content=message.content, author="user"))
        chat_history.append(cl.Message(content=output, author="assistant"))
        cl.user_session.set("chat_history", chat_history)
    else:
        # Direct response from the model
        output = response_message.content
        chat_history.append(cl.Message(content=message.content, author="user"))
        chat_history.append(cl.Message(content=output, author="assistant"))
        cl.user_session.set("chat_history", chat_history)
        await cl.Message(content=output, author="Fashion Agent").send()


@cl.on_chat_end
def on_chat_end():
    print("Goodbye", cl.user_session.get("id"))
if __name__ == "__main__":
    from chainlit.cli import run_chainlit
    os.environ["CHAINLIT_PORT"] = CHAINLIT_PORT
    run_chainlit(__file__)