|
import json |
|
import gradio as gr |
|
import plotly.graph_objects as go |
|
import networkx as nx |
|
from typing import List, Dict, Optional |
|
|
|
from langchain_openai.chat_models import ChatOpenAI |
|
from dialog2graph.pipelines.model_storage import ModelStorage |
|
from dialog2graph.pipelines.d2g_llm.pipeline import D2GLLMPipeline |
|
from dialog2graph.pipelines.helpers.parse_data import PipelineRawDataType |
|
|
|
|
|
def initialize_pipeline():
    """Build the dialog-to-graph LLM pipeline.

    A fresh ``ModelStorage`` is created, a ``ChatOpenAI`` model configured
    with ``gpt-3.5-turbo`` is registered in it under ``"my_filling_model"``,
    and a ``D2GLLMPipeline`` named ``"d2g_pipeline"`` that uses this model as
    its filling LLM is returned.
    """
    filling_model_key = "my_filling_model"

    storage = ModelStorage()
    storage.add(
        filling_model_key,
        config={"model_name": "gpt-3.5-turbo"},
        model_type=ChatOpenAI,
    )

    return D2GLLMPipeline(
        "d2g_pipeline",
        model_storage=storage,
        filling_llm=filling_model_key,
    )
|
|
|
def load_dialog_data(json_file: str, custom_dialog_json: Optional[str] = None) -> List[List[Dict[str, str]]]:
    """Load dialog data and normalise it to a list of dialogs.

    Args:
        json_file: Dataset name. ``"custom"`` selects ``custom_dialog_json``;
            any other value is read from the file ``f"{json_file}.json"``.
        custom_dialog_json: JSON text used when ``json_file`` is ``"custom"``.

    Returns:
        A list of dialogs, each dialog being a list of turns. The decoded JSON
        is normalised as follows:

        * a list whose first element is a list is returned unchanged;
        * any other non-empty list is treated as one dialog and wrapped;
        * a non-list value is treated as a single turn and wrapped twice.

        An empty list is returned when the input is missing, the file cannot
        be found, the JSON is invalid, or the decoded value is empty/falsy.
    """
    if json_file == "custom":
        # Without this guard an empty custom input would fall through to the
        # file branch and report a confusing "custom.json not found".
        if not custom_dialog_json or not custom_dialog_json.strip():
            gr.Warning("No custom dialog JSON provided!")
            return []
        try:
            data = json.loads(custom_dialog_json)
        except json.JSONDecodeError as e:
            # gr.Error only shows up when *raised*; gr.Warning displays when
            # called, which lets us keep the "return []" contract.
            gr.Warning(f"Invalid JSON format in custom dialog: {str(e)}")
            return []
    else:
        file_path = f"{json_file}.json"
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            gr.Warning(f"File {file_path} not found!")
            return []
        except json.JSONDecodeError:
            gr.Warning(f"Invalid JSON format in {file_path}!")
            return []

    if not data:
        return []

    if isinstance(data, list):
        # Non-empty list: either already a list of dialogs, or a single dialog.
        return data if isinstance(data[0], list) else [data]

    # A lone object (e.g. one turn dict) becomes a one-turn, one-dialog list.
    return [[data]]
|
|
|
def create_network_visualization(graph: nx.Graph) -> go.Figure:
    """Create a Plotly network visualization from a NetworkX graph.

    Node positions come from ``nx.spring_layout`` (``k=1``, 50 iterations).
    No seed is passed to the layout, so positions may differ between calls.

    Args:
        graph: Graph to draw. A node's ``'label'`` attribute, when present,
            is shown in the hover text; otherwise ``str(node)`` is used.

    Returns:
        A figure with two scatter traces: grey line segments for the edges
        and markers (with the node id as text) for the nodes. Marker colour
        encodes the number of neighbours reported by ``graph.neighbors``.
    """
    pos = nx.spring_layout(graph, k=1, iterations=50)

    node_x = []
    node_y = []
    node_text = []
    node_ids = []
    node_adjacencies = []
    for node in graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_label = graph.nodes[node].get('label', str(node))
        node_text.append(f"Node {node}<br>{node_label}")
        node_ids.append(node)
        node_adjacencies.append(len(list(graph.neighbors(node))))

    # Each edge contributes "start, end, None"; the None breaks the line so
    # all edges can share a single trace.
    edge_x = []
    edge_y = []
    for edge in graph.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=2, color='#888'),
        hoverinfo='none',
        mode='lines',
    )

    # The marker is specified once here; previously a plain marker was passed
    # to the constructor and then replaced wholesale by this one.
    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers+text',
        hoverinfo='text',
        hovertext=node_text,
        text=[str(node) for node in node_ids],
        textposition="middle center",
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            reversescale=True,
            color=node_adjacencies,
            size=20,
            colorbar=dict(
                thickness=15,
                len=0.5,
                x=1.02,
                title="Node Connections",
                xanchor="left",
            ),
            line=dict(width=2),
        ),
    )

    layout = go.Layout(
        title=dict(
            text='Dialog Graph Visualization',
            font=dict(size=16),
        ),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        annotations=[
            dict(
                text="Hover over nodes for more information",
                showarrow=False,
                xref="paper",
                yref="paper",
                x=0.005,
                y=-0.002,
                xanchor='left',
                yanchor='bottom',
                font=dict(color="#888", size=12),
            )
        ],
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        plot_bgcolor='white',
    )

    return go.Figure(data=[edge_trace, node_trace], layout=layout)
|
|
|
def create_chat_visualization(dialog_data: List[Dict[str, str]], dialog_index: int = 0, total_dialogs: int = 1) -> str:
    """Render one dialog as chat-style HTML with a "Dialog i of n" header.

    Args:
        dialog_data: Turns of a single dialog; each turn must provide the
            ``'participant'`` and ``'text'`` keys.
        dialog_index: Zero-based index of the dialog (displayed one-based).
        total_dialogs: Total number of dialogs, displayed in the header.

    Returns:
        An HTML fragment. Turns whose participant is ``'assistant'`` are drawn
        as left-aligned blue bubbles labelled "Assistant"; every other
        participant is drawn as a right-aligned green bubble labelled "User".
        Turn text is HTML-escaped before being embedded.

    Raises:
        KeyError: If a turn lacks ``'participant'`` or ``'text'``.
    """
    # Imported locally so the function does not rely on a module-level import.
    from html import escape

    parts = [
        f"""
    <div style="margin-bottom: 10px; text-align: center; font-weight: bold; color: #666;">
        Dialog {dialog_index + 1} of {total_dialogs}
    </div>
    <div style="max-height: 500px; overflow-y: auto; border: 1px solid #ddd; border-radius: 10px; padding: 20px; background-color: #f9f9f9;">
    """
    ]

    for turn in dialog_data:
        participant = turn['participant']
        # Dialog text may come from user-supplied JSON; escape it so markup in
        # the text is displayed literally instead of being interpreted.
        safe_text = escape(str(turn['text']))

        if participant == 'assistant':
            justify, background, corner = "flex-start", "#e3f2fd", "left"
            accent, label = "#1976d2", "Assistant"
        else:
            justify, background, corner = "flex-end", "#e8f5e8", "right"
            accent, label = "#388e3c", "User"

        parts.append(
            f"""
        <div style="display: flex; justify-content: {justify}; margin-bottom: 15px;">
            <div style="max-width: 70%; background-color: {background}; padding: 12px 16px; border-radius: 18px; border-bottom-{corner}-radius: 4px; box-shadow: 0 1px 2px rgba(0,0,0,0.1);">
                <div style="font-weight: bold; color: {accent}; font-size: 12px; margin-bottom: 4px;">{label}</div>
                <div style="color: #333; line-height: 1.4;">{safe_text}</div>
            </div>
        </div>
        """
        )

    parts.append("</div>")
    return "".join(parts)
|
|
|
def process_dialog_and_visualize(dialog_choice: str, custom_dialog: str = "", current_dialog_index: int = 0) -> tuple:
    """Run the pipeline on the selected dialogs and build all UI outputs.

    Args:
        dialog_choice: Dataset name, or ``"custom"`` to use ``custom_dialog``.
        custom_dialog: JSON text; only forwarded when the choice is "custom".
        current_dialog_index: Index of the dialog to show in the chat view;
            clamped to the range of loaded dialogs.

    Returns:
        A 7-tuple: (figure or ``None``, summary markdown, chat HTML, dialog
        index, number of dialogs, visibility update, visibility update). The
        two updates are visible only when more than one dialog was loaded.
        On a load failure or any exception, the figure is ``None``, the
        summary carries the message, and both updates hide their target.
    """
    try:
        custom_json = custom_dialog if dialog_choice == "custom" else None
        dialogs = load_dialog_data(dialog_choice, custom_json)

        if not dialogs:
            return (
                None,
                "Failed to load dialog data",
                "",
                0,
                len(dialogs),
                gr.update(visible=False),
                gr.update(visible=False),
            )

        dialog_count = len(dialogs)
        # Keep the requested index inside [0, dialog_count - 1].
        current_dialog_index = max(0, min(current_dialog_index, dialog_count - 1))

        pipeline = initialize_pipeline()
        graph, _report = pipeline.invoke(PipelineRawDataType(dialogs=dialogs))

        fig = create_network_visualization(graph.graph)

        current_dialog = dialogs[current_dialog_index]
        chat_viz = create_chat_visualization(current_dialog, current_dialog_index, dialog_count)

        num_nodes = graph.graph.number_of_nodes()
        num_edges = graph.graph.number_of_edges()
        total_turns = sum(map(len, dialogs))

        summary = "\n".join(
            [
                "",
                "## Graph Summary",
                f"- **Number of nodes**: {num_nodes}",
                f"- **Number of edges**: {num_edges}",
                f"- **Total dialogs**: {dialog_count}",
                f"- **Total dialog turns**: {total_turns}",
                f"- **Currently viewing**: Dialog {current_dialog_index + 1} ({len(current_dialog)} turns)",
                "",
                "## Processing Report",
                f"Generated graph from {dialog_count} dialog(s) with {total_turns} total turns resulting in {num_nodes} nodes and {num_edges} edges.",
                "",
            ]
        )

        # Navigation only makes sense when there is more than one dialog.
        show_nav = dialog_count > 1
        return (
            fig,
            summary,
            chat_viz,
            current_dialog_index,
            dialog_count,
            gr.update(visible=show_nav),
            gr.update(visible=show_nav),
        )

    except Exception as e:
        # UI callback boundary: report the failure in the summary panel.
        return (
            None,
            f"Error processing dialog: {str(e)}",
            "",
            0,
            0,
            gr.update(visible=False),
            gr.update(visible=False),
        )
|
|
|
|
|
def create_gradio_app():
    """Build the Gradio Blocks UI for the Dialog2Graph visualizer.

    The UI has a dataset selector (with an optional custom-JSON textbox), a
    process button, a graph plot, a markdown summary, and a chat view with
    previous/next buttons for multi-dialog datasets. Changing the selector
    toggles the custom textbox and, for non-custom datasets, processes the
    selection immediately.

    Returns:
        The assembled ``gr.Blocks`` application (not launched).
    """
    with gr.Blocks(title="Dialog2Graph Visualizer") as app:
        gr.Markdown("# Dialog2Graph Interactive Visualizer")
        gr.Markdown("Select a dialog dataset to process and visualize as a graph network using Plotly.")

        # State shared between the processing and navigation callbacks.
        current_dialog_index_state = gr.State(0)
        total_dialogs_state = gr.State(0)

        with gr.Row():
            with gr.Column(scale=1):
                dialog_selector = gr.Radio(
                    choices=["dialog1", "dialog2", "dialog3", "dialog4", "custom"],
                    label="Select Dialog Dataset",
                    value="dialog1",
                    info="Choose one of the available dialog datasets or use custom JSON",
                )

                custom_dialog_input = gr.Textbox(
                    label="Custom Dialog JSON",
                    placeholder='[{"text": "Hello! How can I help?", "participant": "assistant"}, {"text": "I need assistance", "participant": "user"}]',
                    lines=8,
                    visible=False,
                    info="Enter dialog data as JSON array with 'text' and 'participant' fields",
                )

                process_btn = gr.Button(
                    "Process Dialog & Generate Graph",
                    variant="primary",
                    size="lg",
                )

                with gr.Accordion("Dialog Datasets Info", open=False):
                    gr.Markdown(
                        "\n".join(
                            [
                                "",
                                "- **dialog1**: Hotel booking conversation",
                                "- **dialog2**: Food delivery conversation",
                                "- **dialog3**: Technical support conversation",
                                "- **dialog4**: Multiple dialogs (calendar, support, subscription)",
                                "- **custom**: Provide your own dialog as JSON",
                                "",
                            ]
                        )
                    )

            with gr.Column(scale=3):
                plot_output = gr.Plot(label="Graph Visualization")

        with gr.Row():
            with gr.Column(scale=1):
                summary_output = gr.Markdown(label="Analysis Summary")

            with gr.Column(scale=1):
                gr.Markdown("### Dialog Conversation")

                # Hidden until a dataset with more than one dialog is loaded.
                with gr.Row(visible=False) as nav_row:
                    # Arrow glyphs restored; the labels had been mis-encoded.
                    prev_btn = gr.Button("← Previous Dialog", size="sm")
                    next_btn = gr.Button("Next Dialog →", size="sm")

                chat_output = gr.HTML(label="Chat Visualization")

        def navigate_dialog(direction: int, current_index: int, total_dialogs: int, dialog_choice: str, custom_dialog: str):
            """Step ``direction`` dialogs and re-render the summary and chat.

            Returns ``(index, summary markdown, chat HTML)``. The index is left
            unchanged when there is nothing to navigate or rendering fails.
            """
            if total_dialogs <= 1:
                return current_index, "", ""

            # Clamp so repeated clicks at either end stay on the boundary.
            new_index = max(0, min(current_index + direction, total_dialogs - 1))

            try:
                dialog_data_list = load_dialog_data(dialog_choice, custom_dialog if dialog_choice == "custom" else None)
                if not dialog_data_list or new_index >= len(dialog_data_list):
                    return current_index, "", ""

                current_dialog = dialog_data_list[new_index]
                chat_viz = create_chat_visualization(current_dialog, new_index, len(dialog_data_list))
                total_turns = sum(len(dialog) for dialog in dialog_data_list)
            except Exception as e:
                # UI callback boundary: surface the failure in the summary
                # instead of silently blanking the panels.
                return current_index, f"Error navigating dialogs: {str(e)}", ""

            summary = "\n".join(
                [
                    "",
                    "## Graph Summary",
                    f"- **Total dialogs**: {len(dialog_data_list)}",
                    f"- **Total dialog turns**: {total_turns}",
                    f"- **Currently viewing**: Dialog {new_index + 1} ({len(current_dialog)} turns)",
                    "",
                    "## Processing Report",
                    "Navigate between dialogs to view different conversations.",
                    "",
                ]
            )
            return new_index, summary, chat_viz

        def toggle_custom_input(choice):
            """Show the custom JSON textbox only for the "custom" choice."""
            return gr.update(visible=(choice == "custom"))

        def auto_process(choice, custom_text, curr_idx):
            """Process built-in datasets on selection; custom input waits for the button."""
            if choice != "custom":
                return process_dialog_and_visualize(choice, custom_text, curr_idx)
            return (
                None,
                "Select 'Process Dialog & Generate Graph' to process custom dialog",
                "",
                0,
                0,
                gr.update(visible=False),
                gr.update(visible=False),
            )

        # Output order matches the 7-tuple of process_dialog_and_visualize.
        process_outputs = [
            plot_output,
            summary_output,
            chat_output,
            current_dialog_index_state,
            total_dialogs_state,
            nav_row,
            nav_row,
        ]
        nav_inputs = [current_dialog_index_state, total_dialogs_state, dialog_selector, custom_dialog_input]
        nav_outputs = [current_dialog_index_state, summary_output, chat_output]

        dialog_selector.change(
            fn=toggle_custom_input,
            inputs=[dialog_selector],
            outputs=[custom_dialog_input],
        )

        process_btn.click(
            fn=process_dialog_and_visualize,
            inputs=[dialog_selector, custom_dialog_input, current_dialog_index_state],
            outputs=process_outputs,
        )

        prev_btn.click(
            fn=lambda curr_idx, total, choice, custom: navigate_dialog(-1, curr_idx, total, choice, custom),
            inputs=nav_inputs,
            outputs=nav_outputs,
        )

        next_btn.click(
            fn=lambda curr_idx, total, choice, custom: navigate_dialog(1, curr_idx, total, choice, custom),
            inputs=nav_inputs,
            outputs=nav_outputs,
        )

        dialog_selector.change(
            fn=auto_process,
            inputs=[dialog_selector, custom_dialog_input, current_dialog_index_state],
            outputs=process_outputs,
        )

    return app
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Blocks UI and start the Gradio server.
    app = create_gradio_app()
    app.launch()
|
|