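"""Chainlit entry point for the FashionM3 shopping assistant.

Wires a Chainlit chat UI to an OpenAI chat model and a set of MCP tools
(image retrieval, image generation, fashion recommendation, virtual try-on).
"""
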
import json
import os
import uuid

import chainlit as cl
import httpx
import pandas as pd
from dotenv import load_dotenv
from openai import AsyncOpenAI

from mcp_client import MCPClient
from system_message import SYSTEM_MESSAGE
from utils import create_image_grid


# Load environment variables from .env
load_dotenv()

# Read configuration values from the environment
CHAINLIT_PORT = os.getenv("CHAINLIT_PORT", "8888")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
FASHION_DATA_ROOT = os.getenv("FASHION_DATA_ROOT")
PROXY = os.getenv("PROXY")
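# Item metadata; item_id_set is used later to detect uploads that match catalogue items.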
items_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/items_lite.parquet")
item_id_set = set(items_df.item_id)

# AsyncOpenAI needs an async HTTP client; route traffic through the proxy only when one is configured.
http_client = httpx.AsyncClient(proxy=PROXY) if PROXY else httpx.AsyncClient()


class FashionAgent:
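    """Bundles the MCP tool client and the OpenAI async client for one user."""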
    def __init__(self, user_id=None):
        self.mcp_client = MCPClient("mcp_server_config.json", user_id)
        self.openai = AsyncOpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE, http_client=http_client)
        self.user_id = user_id


# Global FashionAgent instance, shared by all chat sessions
agent = FashionAgent(user_id=None)


@cl.on_chat_start
async def on_chat_start():
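    """Connect to the configured MCP servers and greet the user when a new chat starts."""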
    await agent.mcp_client.connect_to_servers()
    cl.user_session.set("agent", agent)
    await cl.Message(content="Hello Sophia! Welcome to FashionM3. How can I assist you today?").send()


@cl.on_message
async def on_message(message: cl.Message):
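    """Handle one user turn: attach uploads and user context, query the model, and dispatch any tool calls."""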
    agent = cl.user_session.get("agent")
    user_id = cl.user_session.get("user_id")
    chat_history = cl.user_session.get("chat_history", [])

    user_message = message.content

    upload_image = [x.path for x in message.elements if isinstance(x, cl.Image)]
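    # A single upload is passed through as-is; multiple uploads are merged into one
    # grid image (at most four) so the tools receive a single file path.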
    if len(upload_image) == 1:
        user_message += f"\nThe uploaded image path is: {os.path.abspath(upload_image[0])}"
    elif len(upload_image) > 1:
        merged_image_path = f".files/{uuid.uuid4().hex}.jpg"
        create_image_grid(upload_image[:4], merged_image_path)

        user_message += f"\nThe uploaded image path is: {os.path.abspath(merged_image_path)}"

    image_in_database = []
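    # Uploads whose filename (without extension) matches a known item_id are passed
    # to the tools as references to catalogue items.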
    for image in message.elements:
        if isinstance(image, cl.Image):
            item_id = image.name.split(".")[0]
            if item_id in item_id_set:
                image_in_database.append(item_id)
    if len(image_in_database) > 0:
        user_message += f"\nUser id is: {user_id}"
        user_message += f"\nlist_of_items are: {image_in_database}"
    elif user_id:
        user_message += f"\nUser id is: {user_id}"

    # Prepare messages for OpenAI API
    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        *[{"role": "user" if isinstance(msg, cl.Message) else "assistant", "content": msg.content} for msg in chat_history],
        {"role": "user", "content": user_message}
    ]

    # Fetch available tools
    available_tools = await agent.mcp_client.get_tools()

    # Initial OpenAI API call
    response = await agent.openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        max_tokens=1000,
        tools=available_tools if available_tools else None,
        tool_choice="auto" if available_tools else None
    )

    # Process the response
    response_message = response.choices[0].message
    if response_message.tool_calls:
        # Handle tool calls
        for tool_call in response_message.tool_calls:
            tool_name = tool_call.function.name
            params = json.loads(tool_call.function.arguments)
            try:
                print(f"Agent execute {tool_name} with params: {params}")
                result = await agent.mcp_client.execute_tool(tool_name, params)
                if tool_name == "retrieve_image":
                    image_path = json.loads(result['result'][0].text)['image_path']
                    similarity = json.loads(result['result'][0].text)['similarity']
                    output = f"I found a matching fashion item with a similarity score of {similarity:.2f}"

                    images = [cl.Image(path=image_path, name="Product image", display="inline", size="medium")]
                    await cl.Message(content=output, elements=images, author="Fashion Agent").send()
                if tool_name == "image_generate":
                    image_path = result['result'][0].text
                    images = [cl.Image(path=image_path, name="Product image", display="inline", size="medium")]
                    output = f"Here is the generated image."
                    await cl.Message(content=output, elements=images, author="Fashion Agent").send()
                if tool_name == "fashion_recommend_without_image":
                    output = result['result'][0].text
                    await cl.Message(content=output, author="Fashion Agent").send()
                if tool_name == "fashion_recommend":
                    output = json.loads(result['result'][0].text)['recommendation']
                    # user_preference = json.loads(result['result'][0].text)['user_preference']
                    # await cl.Message(content=user_preference, author="Fashion Agent").send()
                    await cl.Message(content=output, author="Fashion Agent").send()
                if tool_name == "try_on":
                    image_path = result['result'][0].text
                    images = [cl.Image(path=image_path, name="Try-on image", display="inline", size="large")]
                    output = f"Here is the virtual try-on image."
                    await cl.Message(content=output, elements=images, author="Fashion Agent").send()
                else:
                    output = result
            except Exception as e:
                output = f"Error executing tool {tool_name}: {str(e)}"

            # Update chat history
            chat_history.append(cl.Message(content=message.content, author="user"))
            chat_history.append(cl.Message(content=output, author="assistant"))
            cl.user_session.set("chat_history", chat_history)
    else:
        # Direct response from the model
        output = response_message.content
        chat_history.append(cl.Message(content=message.content, author="user"))
        chat_history.append(cl.Message(content=output, author="assistant"))
        cl.user_session.set("chat_history", chat_history)
        await cl.Message(content=output, author="Fashion Agent").send()


@cl.on_chat_end
def on_chat_end():
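    """Log when a chat session ends."""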
    print("Goodbye", cl.user_session.get("id"))


if __name__ == "__main__":
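    # Running this file directly launches the Chainlit server on CHAINLIT_PORT.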
    from chainlit.cli import run_chainlit

    os.environ["CHAINLIT_PORT"] = CHAINLIT_PORT
    run_chainlit(__file__)