# Comfy Virtual Assistant — Gradio chat app backed by a Graph-RAG index.
import os
import gradio as gr
from gradio import ChatMessage
import base64
from llama_index.core import StorageContext, load_index_from_storage
from dotenv import load_dotenv
from retrieve import get_latest_dir, get_latest_html_file
from graph_handler import query_graph_qa, plot_subgraph
from embed_handler import query_rag_qa
from evaluate import evaluate_llm, reasoning_graph, get_coupon
import base64
import time
# Load environment variables (GRAPH_DIR, GRAPH_VIS, EMBEDDING_DIR, ID, PASS)
# from a local .env file before the os.getenv lookups below.
load_dotenv()

# Resolve the most recently persisted artifacts from env-configured directories.
KG_INDEX_PATH = get_latest_dir(os.getenv("GRAPH_DIR"))       # knowledge-graph index
KG_PLOT_PATH = get_latest_html_file(os.getenv("GRAPH_VIS"))  # latest graph HTML (not read again below; show_graph re-resolves)
RAG_INDEX_PATH = get_latest_dir(os.getenv("EMBEDDING_DIR"))  # embedding (vector) index

# Load Graph-RAG index
graph_rag_index = load_index_from_storage(
    StorageContext.from_defaults(persist_dir=KG_INDEX_PATH)
)
# Load RAG index
rag_index = load_index_from_storage(
    StorageContext.from_defaults(persist_dir=RAG_INDEX_PATH)
)
def query_tqa(query, search_level):
    """
    Query the Graph-RAG index for a given query.

    Args:
        query (str): The query to ask the RAG.
        search_level (int): The max search level to use for the Graph RAG.

    Returns:
        str: The Graph-RAG response text.

    Raises:
        gr.Error: If the query is empty or whitespace-only.
    """
    if not query.strip():
        raise gr.Error("Please enter a query before asking.")
    grag_response, grag_reference, grag_reference_text = query_graph_qa(
        graph_rag_index, query, search_level
    )
    # Return the plain response string. The previous version returned a
    # one-element tuple (leftover from a multi-value return), which made the
    # chat caller render a tuple as the assistant message content.
    return str(grag_response.response)
def show_graph():
    """
    Render the most recent graph visualization as an embedded iframe.

    Returns:
        str: HTML for an iframe wrapping the base64-encoded visualization,
        or a plain status/error message when none is available.
    """
    vis_dir = os.getenv("GRAPH_VIS", "graph_vis")
    try:
        newest = get_latest_html_file(vis_dir)
        if not newest:
            return "No graph visualization found."
        with open(newest, "r", encoding="utf-8") as fh:
            raw_html = fh.read()
        # Embed via a data: URI so the HTML renders inside the Gradio page.
        payload = base64.b64encode(raw_html.encode()).decode()
        return (
            f'<iframe src="data:text/html;base64,{payload}" '
            'width="100%" height="1000px" frameborder="0"></iframe>'
        )
    except Exception as err:
        return f"Error: {str(err)}"
def reveal_coupon(query, grag_response):
    """
    Extract the coupon for a completed query/response exchange.

    Args:
        query (str): Query asked to Graph-RAG.
        grag_response (str): Response from the Graph-RAG model.

    Returns:
        str: Coupon with reasoning.

    Raises:
        gr.Error: If either the query or the response is blank.
    """
    missing_input = not query.strip() or not grag_response.strip()
    if missing_input:
        raise gr.Error("Please ask a query and get a response before revealing the coupon.")
    return get_coupon(query, grag_response)
# --- Gradio UI wiring --------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Comfy Virtual Assistant")
    chatbot = gr.Chatbot(
        label="Comfy Virtual Assistant",
        type="messages",  # history holds ChatMessage entries, not (user, bot) tuples
        scale=1,
        # suggestions = [
        # {"text": "How much iphone cost?"},
        # {"text": "What phone options do i have ?"}
        # ],
    )
    msg = gr.Textbox(label="Input Your Query")
    clear = gr.ClearButton([msg, chatbot])

    def respond(message, chat_history):
        """
        Chat callback: query the Graph-RAG backend and append the exchange.

        Args:
            message (str): User input from the textbox.
            chat_history (list): Current history managed by the Chatbot component.

        Returns:
            tuple: ("", updated history) — the empty string clears the textbox.
        """
        # Search level is fixed at 2; query_tqa's return value is forwarded
        # as-is as the assistant message content.
        bot_message = query_tqa(message, 2)
        # chat_history.append((message, bot_message))
        chat_history.append(ChatMessage(role="user", content=message))
        chat_history.append(ChatMessage(role="assistant", content=bot_message))
        time.sleep(1)  # brief pause before the UI update is returned
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])

# Basic-auth credentials come from the ID / PASS environment variables.
demo.launch(auth=(os.getenv("ID"), os.getenv("PASS")), share=False)