|
import html
import json
from typing import Dict, List, Optional

import gradio as gr
import networkx as nx
import plotly.graph_objects as go
from langchain_openai.chat_models import ChatOpenAI

from dialog2graph.pipelines.d2g_llm.pipeline import D2GLLMPipeline
from dialog2graph.pipelines.helpers.parse_data import PipelineRawDataType
from dialog2graph.pipelines.model_storage import ModelStorage
|
|
|
|
|
def initialize_pipeline():
    """Build a D2GLLMPipeline backed by a ChatOpenAI filling model.

    Registers a gpt-3.5-turbo model in a fresh ModelStorage under the key
    "my_filling_model" and hands that storage to the pipeline, which looks
    the model up by the same key.
    """
    storage = ModelStorage()
    storage.add(
        "my_filling_model",
        config={"model_name": "gpt-3.5-turbo"},
        model_type=ChatOpenAI,
    )
    return D2GLLMPipeline(
        "d2g_pipeline",
        model_storage=storage,
        filling_llm="my_filling_model",
    )
|
|
|
def load_dialog_data(json_file: str, custom_dialog_json: Optional[str] = None) -> List[Dict[str, str]]:
    """Load dialog turns from a bundled JSON file or a custom JSON string.

    Args:
        json_file: Base name of a bundled dataset (e.g. "dialog1" loads
            "dialog1.json"), or the literal string "custom" to parse
            ``custom_dialog_json`` instead.
        custom_dialog_json: Raw JSON text, used only when ``json_file`` is
            "custom" and the string is non-empty.

    Returns:
        A list of dialog turns (dicts with "text" and "participant" keys —
        assumed, not validated here).

    Raises:
        gr.Error: On a missing file or malformed JSON.  The original code
            constructed ``gr.Error(...)`` without raising it, which is a
            no-op — per the Gradio docs an error must be *raised* to be
            shown to the user.
    """
    if json_file == "custom" and custom_dialog_json:
        try:
            return json.loads(custom_dialog_json)
        except json.JSONDecodeError as e:
            raise gr.Error(f"Invalid JSON format in custom dialog: {str(e)}") from e

    file_path = f"{json_file}.json"
    try:
        with open(file_path, 'r', encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError as e:
        raise gr.Error(f"File {file_path} not found!") from e
    except json.JSONDecodeError as e:
        raise gr.Error(f"Invalid JSON format in {file_path}!") from e
|
|
|
def create_network_visualization(graph: nx.Graph) -> go.Figure:
    """Create a Plotly network visualization from a NetworkX graph.

    Edges are drawn as plain grey lines; nodes are colored by neighbor count
    with a colorbar.  A node's ``label`` attribute, when present, is used for
    its hover text.

    Args:
        graph: The graph to draw (node/edge ``label`` attributes optional).

    Returns:
        A Plotly figure with one edge trace and one node trace.
    """
    # Force-directed layout; k/iterations chosen for small dialog graphs.
    pos = nx.spring_layout(graph, k=1, iterations=50)

    node_x, node_y, node_text, node_ids = [], [], [], []
    for node in graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        label = graph.nodes[node].get('label', str(node))
        node_text.append(f"Node {node}<br>{label}")
        node_ids.append(node)

    # A single Scatter trace draws all edges; the None entries break the
    # polyline between consecutive edges.
    edge_x, edge_y = [], []
    for edge in graph.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
    # NOTE: the original also collected per-edge labels into an unused
    # `edge_info` list — edge labels were never rendered, so that dead
    # code is dropped here.

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=2, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    # Color nodes by their neighbor count.
    node_adjacencies = [len(list(graph.neighbors(node))) for node in graph.nodes()]

    # The original passed a bare marker dict to the constructor and then
    # overwrote node_trace.marker wholesale afterwards; define it once.
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        hovertext=node_text,
        text=[str(node) for node in node_ids],
        textposition="middle center",
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            reversescale=True,
            color=node_adjacencies,
            size=20,
            colorbar=dict(
                thickness=15,
                len=0.5,
                x=1.02,
                title="Node Connections",
                xanchor="left"
            ),
            line=dict(width=2)
        )
    )

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            title=dict(
                text='Dialog Graph Visualization',
                font=dict(size=16),
            ),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[dict(
                text="Hover over nodes for more information",
                showarrow=False,
                xref="paper", yref="paper",
                x=0.005, y=-0.002,
                xanchor='left', yanchor='bottom',
                font=dict(color="#888", size=12)
            )],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor='white'
        )
    )

    return fig
|
|
|
def create_chat_visualization(dialog_data: List[Dict[str, str]]) -> str:
    """Render dialog turns as a chat-style HTML transcript.

    Assistant turns appear as left-aligned blue bubbles; any other
    participant is shown as a right-aligned green "User" bubble.

    Args:
        dialog_data: Turns with "participant" and "text" keys.

    Returns:
        An HTML fragment (scrollable container with one bubble per turn).
    """
    chat_html = """
    <div style="max-height: 500px; overflow-y: auto; border: 1px solid #ddd; border-radius: 10px; padding: 20px; background-color: #f9f9f9;">
    """

    for turn in dialog_data:
        participant = turn['participant']
        # Escape message text: custom dialogs are user-supplied, so raw text
        # would otherwise be injected directly into the page as markup.
        text = html.escape(turn['text'])

        if participant == 'assistant':
            chat_html += f"""
            <div style="display: flex; justify-content: flex-start; margin-bottom: 15px;">
                <div style="max-width: 70%; background-color: #e3f2fd; padding: 12px 16px; border-radius: 18px; border-bottom-left-radius: 4px; box-shadow: 0 1px 2px rgba(0,0,0,0.1);">
                    <div style="font-weight: bold; color: #1976d2; font-size: 12px; margin-bottom: 4px;">Assistant</div>
                    <div style="color: #333; line-height: 1.4;">{text}</div>
                </div>
            </div>
            """
        else:
            chat_html += f"""
            <div style="display: flex; justify-content: flex-end; margin-bottom: 15px;">
                <div style="max-width: 70%; background-color: #e8f5e8; padding: 12px 16px; border-radius: 18px; border-bottom-right-radius: 4px; box-shadow: 0 1px 2px rgba(0,0,0,0.1);">
                    <div style="font-weight: bold; color: #388e3c; font-size: 12px; margin-bottom: 4px;">User</div>
                    <div style="color: #333; line-height: 1.4;">{text}</div>
                </div>
            </div>
            """

    chat_html += "</div>"
    return chat_html
|
|
|
def process_dialog_and_visualize(dialog_choice: str, custom_dialog: str = "") -> tuple:
    """Run the full pipeline on the chosen dialog and build every UI output.

    Returns a (figure, markdown summary, chat HTML) triple.  On any failure
    the figure slot is None and the summary carries the error text, so the
    UI degrades gracefully instead of crashing.
    """
    try:
        # Only forward the custom JSON when the custom dataset is selected.
        custom = custom_dialog if dialog_choice == "custom" else None
        dialog_data = load_dialog_data(dialog_choice, custom)
        if not dialog_data:
            return None, "Failed to load dialog data", ""

        # Turn the raw dialog turns into a graph via the LLM pipeline.
        pipeline = initialize_pipeline()
        raw = PipelineRawDataType(dialogs=dialog_data)
        graph, report = pipeline.invoke(raw)

        figure = create_network_visualization(graph.graph)
        chat_viz = create_chat_visualization(dialog_data)

        num_nodes = graph.graph.number_of_nodes()
        num_edges = graph.graph.number_of_edges()
        summary = f"""
        ## Graph Summary
        - **Number of nodes**: {num_nodes}
        - **Number of edges**: {num_edges}
        - **Dialog turns**: {len(dialog_data)}

        ## Processing Report
        Generated graph from {len(dialog_data)} dialog turns with {num_nodes} nodes and {num_edges} edges.
        """

        return figure, summary, chat_viz

    except Exception as e:
        # Boundary handler: surface the failure in the summary pane.
        return None, f"Error processing dialog: {str(e)}", ""
|
|
|
|
|
def create_gradio_app():
    """Build the Gradio Blocks UI and wire its event handlers.

    Layout: dataset picker + custom-JSON textbox + process button on the
    left, the Plotly graph on the right, with the summary and the chat
    transcript in a row below.
    """
    with gr.Blocks(title="Dialog2Graph Visualizer") as app:
        gr.Markdown("# Dialog2Graph Interactive Visualizer")
        gr.Markdown("Select a dialog dataset to process and visualize as a graph network using Plotly.")

        with gr.Row():
            with gr.Column(scale=1):
                # Three bundled datasets plus a free-form "custom" option.
                dialog_selector = gr.Radio(
                    choices=["dialog1", "dialog2", "dialog3", "custom"],
                    label="Select Dialog Dataset",
                    value="dialog1",
                    info="Choose one of the available dialog datasets or use custom JSON"
                )

                # Hidden by default; toggled visible when "custom" is picked
                # (see toggle_custom_input below).
                custom_dialog_input = gr.Textbox(
                    label="Custom Dialog JSON",
                    placeholder='[{"text": "Hello! How can I help?", "participant": "assistant"}, {"text": "I need assistance", "participant": "user"}]',
                    lines=8,
                    visible=False,
                    info="Enter dialog data as JSON array with 'text' and 'participant' fields"
                )

                process_btn = gr.Button(
                    "Process Dialog & Generate Graph",
                    variant="primary",
                    size="lg"
                )

                with gr.Accordion("Dialog Datasets Info", open=False):
                    gr.Markdown("""
                    - **dialog1**: Hotel booking conversation
                    - **dialog2**: Food delivery conversation
                    - **dialog3**: Technical support conversation
                    - **custom**: Provide your own dialog as JSON
                    """)

            with gr.Column(scale=3):
                plot_output = gr.Plot(label="Graph Visualization")

        with gr.Row():
            with gr.Column(scale=1):
                summary_output = gr.Markdown(label="Analysis Summary")

            with gr.Column(scale=1):
                gr.Markdown("### Dialog Conversation")
                chat_output = gr.HTML(label="Chat Visualization")

        def toggle_custom_input(choice):
            # Show the JSON textbox only for the "custom" dataset choice.
            return gr.update(visible=(choice == "custom"))

        dialog_selector.change(
            fn=toggle_custom_input,
            inputs=[dialog_selector],
            outputs=[custom_dialog_input]
        )

        process_btn.click(
            fn=process_dialog_and_visualize,
            inputs=[dialog_selector, custom_dialog_input],
            outputs=[plot_output, summary_output, chat_output]
        )

        def auto_process(choice, custom_text):
            # Auto-run the pipeline when a bundled dataset is selected;
            # custom dialogs wait for an explicit button click.
            if choice != "custom":
                return process_dialog_and_visualize(choice, custom_text)
            else:
                return None, "Select 'Process Dialog & Generate Graph' to process custom dialog", ""

        # NOTE: second .change handler on dialog_selector — Gradio runs both,
        # so one selection change toggles visibility AND auto-processes.
        dialog_selector.change(
            fn=auto_process,
            inputs=[dialog_selector, custom_dialog_input],
            outputs=[plot_output, summary_output, chat_output]
        )

    return app
|
|
|
if __name__ == "__main__": |
|
app = create_gradio_app() |
|
app.launch() |
|
|