import json
import gradio as gr
import plotly.graph_objects as go
import networkx as nx
from typing import List, Dict, Optional
from langchain_openai.chat_models import ChatOpenAI
from dialog2graph.pipelines.model_storage import ModelStorage
from dialog2graph.pipelines.d2g_llm.pipeline import D2GLLMPipeline
from dialog2graph.pipelines.helpers.parse_data import PipelineRawDataType
# Initialize the pipeline
def initialize_pipeline():
    """Build the dialog-to-graph LLM pipeline used by the app.

    A fresh ``ModelStorage`` is created and a ``ChatOpenAI`` model configured
    with ``gpt-3.5-turbo`` is registered in it under the key
    ``"my_filling_model"``. The returned ``D2GLLMPipeline`` is named
    ``"d2g_pipeline"`` and refers to that key as its filling LLM.
    """
    # The same key is used to register the model and to select it below.
    filling_model_key = "my_filling_model"

    storage = ModelStorage()
    storage.add(
        filling_model_key,
        config={"model_name": "gpt-3.5-turbo"},
        model_type=ChatOpenAI,
    )

    return D2GLLMPipeline(
        "d2g_pipeline",
        model_storage=storage,
        filling_llm=filling_model_key,
    )
def load_dialog_data(json_file: str, custom_dialog_json: Optional[str] = None) -> List[List[Dict[str, str]]]:
    """Load dialog data and normalize it to a list of dialogs.

    The data is parsed from ``custom_dialog_json`` when ``json_file`` is
    ``"custom"`` and a non-empty string is supplied; in every other case it is
    read from the file ``f"{json_file}.json"``.

    Args:
        json_file: Name of the file to read, without the ``.json`` suffix, or
            ``"custom"`` to parse ``custom_dialog_json`` instead.
        custom_dialog_json: JSON text used when ``json_file`` is ``"custom"``.

    Returns:
        The decoded data wrapped as needed so that it is a list of lists:
        a list whose first element is a list is returned unchanged, any other
        list is treated as a single dialog and wrapped once, and a non-list
        value is wrapped twice. An empty list is returned when the file is
        missing, the JSON is invalid, or the decoded value is empty/falsy.
    """
    if json_file == "custom" and custom_dialog_json:
        try:
            data = json.loads(custom_dialog_json)
        except json.JSONDecodeError as e:
            # gr.Error is an exception class: merely instantiating it shows
            # nothing. gr.Warning displays its message when called, which lets
            # us notify the user and still return an empty result.
            gr.Warning(f"Invalid JSON format in custom dialog: {str(e)}")
            return []
    else:
        file_path = f"{json_file}.json"
        try:
            # Explicit encoding so the result does not depend on the locale.
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            gr.Warning(f"File {file_path} not found!")
            return []
        except json.JSONDecodeError:
            gr.Warning(f"Invalid JSON format in {file_path}!")
            return []

    # Nothing to normalize (empty list/dict/string, 0, null, ...).
    if not data:
        return []

    if isinstance(data, list):
        # A list of lists is already a list of dialogs; any other list is
        # taken to be one dialog and wrapped once.
        return data if isinstance(data[0], list) else [data]

    # A single non-list value becomes the only item of the only dialog.
    return [[data]]
def create_network_visualization(graph: nx.Graph) -> go.Figure:
    """Create a Plotly network visualization from a NetworkX graph.

    Nodes are positioned with a spring layout and drawn as markers carrying
    their id as text; each marker is colored by the node's number of
    neighbors. Edges are drawn as plain grey line segments without hover
    information.

    Args:
        graph: Graph to draw. A node's ``label`` attribute, when present, is
            shown in its hover text; otherwise the node id is used.

    Returns:
        A figure containing the edge trace followed by the node trace.
    """
    # Get node positions using spring layout
    pos = nx.spring_layout(graph, k=1, iterations=50)

    # Node coordinates, hover texts and neighbor counts, all in node order.
    node_ids = list(graph.nodes())
    node_x = []
    node_y = []
    node_text = []
    node_adjacencies = []
    for node in node_ids:
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_label = graph.nodes[node].get('label', str(node))
        # "<br>" is how Plotly breaks a line inside hover text.
        node_text.append(f"Node {node}<br>{node_label}")
        node_adjacencies.append(sum(1 for _ in graph.neighbors(node)))

    # Each edge contributes its two endpoints followed by None, so the single
    # line trace is broken between consecutive edges.
    edge_x = []
    edge_y = []
    for source, target in graph.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    # Create the edge trace
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=2, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    # Create the node trace, colored by number of connections
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        hovertext=node_text,
        text=[str(node) for node in node_ids],
        textposition="middle center",
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            reversescale=True,
            color=node_adjacencies,
            size=20,
            colorbar=dict(
                thickness=15,
                len=0.5,
                x=1.02,
                title="Node Connections",
                xanchor="left"
            ),
            line=dict(width=2)
        )
    )

    # Create the figure
    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            title=dict(
                text='Dialog Graph Visualization',
                font=dict(
                    size=16,
                ),
            ),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[dict(
                text="Hover over nodes for more information",
                showarrow=False,
                xref="paper", yref="paper",
                x=0.005, y=-0.002,
                xanchor='left', yanchor='bottom',
                font=dict(color="#888", size=12)
            )],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor='white'
        )
    )
    return fig
def create_chat_visualization(dialog_data: List[Dict[str, str]], dialog_index: int = 0, total_dialogs: int = 1) -> str:
"""Create a chat-like visualization of the dialog with navigation info"""
chat_html = f"""