File size: 6,563 Bytes
f8a73ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import json
import os
import pandas as pd
import uuid
from openai import AsyncOpenAI
from dotenv import load_dotenv
import chainlit as cl
from system_message import SYSTEM_MESSAGE
from mcp_client import MCPClient
from utils import create_image_grid
import httpx
# Load environment variables from .env
load_dotenv()

# Runtime configuration, read from the environment (or the .env file loaded above).
CHAINLIT_PORT = os.getenv("CHAINLIT_PORT", "8888")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
FASHION_DATA_ROOT = os.getenv("FASHION_DATA_ROOT")
PROXY = os.getenv("PROXY")

# Fail fast with a clear message instead of trying to read "None/meta/items_lite.parquet".
if not FASHION_DATA_ROOT:
    raise RuntimeError("FASHION_DATA_ROOT is not set; it must point at the fashion data directory")

# Item catalogue. The set of item ids is used to recognise uploaded images whose file
# name (minus extension) is a known item id.
items_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/items_lite.parquet")
item_id_set = set(items_df.item_id)

# HTTP client handed to AsyncOpenAI. AsyncOpenAI only accepts an httpx.AsyncClient, so
# both branches must build the async client (a synchronous httpx.Client is rejected).
http_client = httpx.AsyncClient(proxy=PROXY) if PROXY else httpx.AsyncClient()
class FashionAgent:
    """Hold the MCP tool client and the OpenAI chat client used by the chat handlers.

    Args:
        user_id: Forwarded to ``MCPClient`` and stored as ``self.user_id``; may be ``None``.
    """

    def __init__(self, user_id=None):
        self.mcp_client = MCPClient("mcp_server_config.json", user_id)
        # OPENAI_API_BASE is read from the environment for this purpose; pass it through so
        # a custom endpoint actually takes effect. When it is None the SDK falls back to its
        # own default endpoint resolution.
        self.openai = AsyncOpenAI(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_API_BASE,
            http_client=http_client,
        )
        self.user_id = user_id


# Global FashionAgent instance, created once at import (with no user id) and shared by
# every chat session.
agent = FashionAgent(user_id=None)
@cl.on_chat_start
async def on_chat_start():
    """Connect the shared agent's MCP servers, store it in the session, and greet the user."""
    shared_agent = agent
    # NOTE(review): the same global agent is reconnected on every new chat; confirm that
    # MCPClient.connect_to_servers is safe to call repeatedly.
    await shared_agent.mcp_client.connect_to_servers()
    cl.user_session.set("agent", shared_agent)

    greeting = "Hello Sophia! Welcome to FashionM3. How can I assist you today?"
    await cl.Message(content=greeting).send()
def _compose_user_message(message, user_id):
    """Return the user's text augmented with upload and catalogue context for the model.

    - One uploaded image: its absolute path is appended.
    - Several uploaded images: a grid built from at most the first four is written under
      ``.files/`` and its absolute path is appended.
    - If any uploaded image's file name (without extension) is a known item id, the user
      id and those item ids are appended; otherwise the user id is appended only when
      it is truthy.
    """
    text = message.content
    image_elements = [el for el in (message.elements or []) if isinstance(el, cl.Image)]
    upload_paths = [el.path for el in image_elements]

    if len(upload_paths) == 1:
        text += f"\nThe uploaded image path is: {os.path.abspath(upload_paths[0])}"
    elif len(upload_paths) > 1:
        merged_image_path = f".files/{uuid.uuid4().hex}.jpg"
        create_image_grid(upload_paths[:4], merged_image_path)
        text += f"\nThe uploaded image path is: {os.path.abspath(merged_image_path)}"

    # An upload named "<item_id>.<ext>" refers to a catalogue item.
    candidate_ids = (image.name.split(".")[0] for image in image_elements)
    image_in_database = [item_id for item_id in candidate_ids if item_id in item_id_set]

    if image_in_database:
        text += f"\nUser id is: {user_id}"
        text += f"\nlist_of_items are: {image_in_database}"
    elif user_id:
        text += f"\nUser id is: {user_id}"
    return text


async def _render_tool_result(tool_name, result):
    """Show one tool result in the chat and return the text to record in the history.

    NOTE(review): assumes ``result["result"][0].text`` holds the tool's payload (a JSON
    string for retrieve_image / fashion_recommend, a plain path or text otherwise), as
    the original handler did; confirm against MCPClient.execute_tool.
    """
    images = None
    if tool_name == "retrieve_image":
        payload = json.loads(result['result'][0].text)
        similarity = payload['similarity']
        output = f"I found a matching fashion item with a similarity score of {similarity:.2f}"
        images = [cl.Image(path=payload['image_path'], name="Product image", display="inline", size="medium")]
    elif tool_name == "image_generate":
        image_path = result['result'][0].text
        images = [cl.Image(path=image_path, name="Product image", display="inline", size="medium")]
        output = "Here is the generated image."
    elif tool_name == "fashion_recommend_without_image":
        output = result['result'][0].text
    elif tool_name == "fashion_recommend":
        output = json.loads(result['result'][0].text)['recommendation']
    elif tool_name == "try_on":
        image_path = result['result'][0].text
        images = [cl.Image(path=image_path, name="Try-on image", display="inline", size="large")]
        output = "Here is the virtual try-on image."
    else:
        # Unknown tool: show the raw result rather than leaving the user without a reply.
        output = str(result)
    await cl.Message(content=output, elements=images, author="Fashion Agent").send()
    return output


@cl.on_message
async def on_message(message: cl.Message):
    """Answer one user message, directly or by running the tools the model selects.

    The conversation (system prompt, stored history, augmented user turn) is sent to the
    model with the MCP tools. Each requested tool is executed and its result shown. A
    plain reply is shown as-is. Afterwards the user turn and the reply (or the joined
    tool outputs) are appended once to the session's chat history.
    """
    agent = cl.user_session.get("agent")
    user_id = cl.user_session.get("user_id")
    chat_history = cl.user_session.get("chat_history", [])
    user_message = _compose_user_message(message, user_id)

    # History entries are cl.Message objects authored "user" or "assistant"; map the
    # author to the API role (all entries are cl.Message, so isinstance can't tell them apart).
    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        *[
            {"role": "assistant" if msg.author == "assistant" else "user", "content": msg.content}
            for msg in chat_history
        ],
        {"role": "user", "content": user_message},
    ]

    available_tools = await agent.mcp_client.get_tools()
    response = await agent.openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        max_tokens=1000,
        tools=available_tools if available_tools else None,
        tool_choice="auto" if available_tools else None,
    )
    response_message = response.choices[0].message

    if response_message.tool_calls:
        outputs = []
        for tool_call in response_message.tool_calls:
            tool_name = tool_call.function.name
            try:
                # Malformed arguments from the model are reported like any other tool error.
                params = json.loads(tool_call.function.arguments)
                print(f"Agent execute {tool_name} with params: {params}")
                result = await agent.mcp_client.execute_tool(tool_name, params)
                output = await _render_tool_result(tool_name, result)
            except Exception as e:
                # Handler boundary: surface the failure to the user instead of hiding it.
                output = f"Error executing tool {tool_name}: {str(e)}"
                await cl.Message(content=output, author="Fashion Agent").send()
            outputs.append(output)
        output = "\n\n".join(outputs)
    else:
        output = response_message.content or ""
        await cl.Message(content=output, author="Fashion Agent").send()

    # Record the exchange exactly once, regardless of how many tools ran.
    chat_history.append(cl.Message(content=message.content, author="user"))
    chat_history.append(cl.Message(content=output, author="assistant"))
    cl.user_session.set("chat_history", chat_history)
@cl.on_chat_end
def on_chat_end():
    """Print a goodbye line with the ending session's id to stdout."""
    session_id = cl.user_session.get("id")
    print("Goodbye", session_id)
if __name__ == "__main__":
    # Allow launching with `python <this file>`; the chosen port is exported via the
    # CHAINLIT_PORT environment variable before handing this file to Chainlit's runner.
    from chainlit.cli import run_chainlit

    os.environ.update(CHAINLIT_PORT=CHAINLIT_PORT)
    run_chainlit(__file__)
|