import os
import time
import base64
from typing import List, Tuple

import gradio as gr
from gradio import ChatMessage
from dotenv import load_dotenv
from llama_index.core import StorageContext, load_index_from_storage

from retrieve import get_latest_dir, get_latest_html_file
from graph_handler import query_graph_qa, plot_subgraph
from embed_handler import query_rag_qa
from evaluate import evaluate_llm, reasoning_graph, get_coupon

load_dotenv()

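# Paths to the most recent persisted artifacts. GRAPH_DIR, GRAPH_VIS, and
# EMBEDDING_DIR are expected to be set in the environment (e.g. via .env).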
KG_INDEX_PATH = get_latest_dir(os.getenv("GRAPH_DIR"))
KG_PLOT_PATH = get_latest_html_file(os.getenv("GRAPH_VIS"))
RAG_INDEX_PATH = get_latest_dir(os.getenv("EMBEDDING_DIR"))

# Load Graph-RAG index
graph_rag_index = load_index_from_storage(
    StorageContext.from_defaults(persist_dir=KG_INDEX_PATH)
)

# Load RAG index
rag_index = load_index_from_storage(
    StorageContext.from_defaults(persist_dir=RAG_INDEX_PATH)
)
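
# Note: rag_index is loaded here, but the plain-RAG query path in query_tqa
# below is currently commented out; only Graph-RAG answers are returned.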


def query_tqa(query, search_level):
    """
    Query the Graph-RAG index for a given query.

    Args:
        query (str): The query to ask the assistant.
        search_level (int): The max search level for the Graph-RAG retriever.

    Returns:
        str: The Graph-RAG response text.
    """
    if not query.strip():
        raise gr.Error("Please enter a query before asking.")

    grag_response, grag_reference, grag_reference_text = query_graph_qa(
        graph_rag_index, query, search_level
    )
    # Plain-RAG path, currently disabled:
    # rag_response, rag_reference, rag_reference_text = query_rag_qa(
    #     rag_index, query, search_level
    # )
    return str(grag_response.response)
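
# Minimal usage sketch (hypothetical query string; assumes the indexes above
# loaded successfully):
#   answer = query_tqa("What phone options do I have?", search_level=2)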


# Previous chat UI implementation, kept commented out for reference:
# with gr.Blocks() as demo:
#     gr.Markdown("# Comfy Virtual Assistant")
#     chatbot = gr.Chatbot(
#         label="Comfy Virtual Assistant",
#         type="messages",
#         scale=1,
#         # suggestions = [
#         #         {"text": "How much iphone cost?"},
#         #         {"text": "What phone options do i have ?"}
#         #         ],
#     )
#     msg = gr.Textbox(label="Input Your Query")
#     clear = gr.ClearButton([msg, chatbot])

#     def respond(message, chat_history):
#             bot_message = query_tqa(message, 2)
#             # chat_history.append((message, bot_message))
#             chat_history.append(ChatMessage(role="user", content=message))
#             chat_history.append(ChatMessage(role="assistant", content=bot_message))
#             time.sleep(1)
#             return "", chat_history

#     msg.submit(respond, [msg, chatbot], [msg, chatbot])

def chatbot_response(message: str, history: List[Tuple[str, str]]) -> str:
    # Delegate to query_tqa; the chat history is not used for retrieval yet.
    search_level = 2  # Max Graph-RAG search depth; adjust or make configurable.
    response = query_tqa(message, search_level)
    return response

# Create the Gradio interface
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Comfy Virtual Assistant")
    msg = gr.Textbox(label="Input Your Query")
    clear = gr.Button("Clear")

    def user(user_message, history):
        # Append the user turn immediately and clear the textbox.
        return "", history + [[user_message, None]]

    def bot(history):
        # Fill in the assistant reply for the latest user turn.
        bot_message = chatbot_response(history[-1][0], history)
        history[-1][1] = bot_message
        return history

    # Show the user message first, then compute the reply once query_tqa returns.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

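# Basic-auth credentials are read from the ID and PASS environment variables;
# the app falls back to no auth if they are unset.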
auth = (os.getenv("ID"), os.getenv("PASS"))
demo.launch(auth=auth if all(auth) else None, share=False)