Victarry's picture
Add interactive server.
5126943
raw
history blame
9.27 kB
import dash
from dash import dcc, html, Input, Output, State, callback_context
import plotly.graph_objects as go
import webbrowser
from threading import Timer
from src.execution_model import ScheduleConfig, Schedule
from src.strategies import (
generate_1f1b_schedule,
generate_zero_bubble_1p_schedule,
generate_1f1b_overlap_schedule,
generate_1f1b_interleave_schedule,
generate_1f1b_interleave_overlap_schedule,
generate_dualpipe_schedule
)
from src.visualizer import convert_schedule_to_visualization_format, create_pipeline_figure
def open_browser(port):
webbrowser.open_new(f"http://127.0.0.1:{port}")
STRATEGIES = {
"1f1b": generate_1f1b_schedule,
"zb1p": generate_zero_bubble_1p_schedule,
"1f1b_overlap": generate_1f1b_overlap_schedule,
"1f1b_interleave": generate_1f1b_interleave_schedule,
"1f1b_interleave_overlap": generate_1f1b_interleave_overlap_schedule,
"dualpipe": generate_dualpipe_schedule,
}
app = dash.Dash(__name__, suppress_callback_exceptions=True)
app.title = "Pipeline Parallelism Visualizer"
# Initial default values
default_values = {
"num_devices": 4,
"num_stages": 8,
"num_batches": 16,
"p2p_latency": 0.1,
"op_time_forward": 1.0,
"op_time_backward_d": 1.0,
"op_time_backward_w": 1.0,
"op_time_backward": 2.0,
"strategy": "1f1b_interleave",
"split_backward": False,
"placement_strategy": "interleave"
}
app.layout = html.Div([
html.H1("Pipeline Parallelism Schedule Visualizer", style={'textAlign': 'center'}),
html.Div([
html.Div([
html.Label("Number of Devices (GPUs):"),
dcc.Input(id='num_devices', type='number', value=default_values["num_devices"], min=1, step=1, style={'width': '100%'}),
html.Label("Number of Stages (Model Chunks):"),
dcc.Input(id='num_stages', type='number', value=default_values["num_stages"], min=1, step=1, style={'width': '100%'}),
html.Label("Number of Microbatches:"),
dcc.Input(id='num_batches', type='number', value=default_values["num_batches"], min=1, step=1, style={'width': '100%'}),
html.Label("P2P Latency (ms):"),
dcc.Input(id='p2p_latency', type='number', value=default_values["p2p_latency"], min=0, step=0.01, style={'width': '100%'}),
], style={'padding': 10, 'flex': 1}),
html.Div([
html.Label("Scheduling Strategy:"),
dcc.Dropdown(
id='strategy',
options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
value=default_values["strategy"],
clearable=False,
style={'width': '100%'}
),
html.Label("Placement Strategy:"),
dcc.Dropdown(
id='placement_strategy',
options=[
{'label': 'Standard', 'value': 'standard'},
{'label': 'Interleave', 'value': 'interleave'},
{'label': 'DualPipe', 'value': 'dualpipe'}
],
value=default_values["placement_strategy"],
clearable=False,
style={'width': '100%'}
),
html.Div([ # Wrap checkbox and label
dcc.Checklist(
id='split_backward',
options=[{'label': ' Split Backward Pass (for ZB-1P, DualPipe)', 'value': 'True'}],
value=['True'] if default_values["split_backward"] else [],
style={'display': 'inline-block'}
),
], style={'marginTop': '20px'}),
], style={'padding': 10, 'flex': 1}),
html.Div([
html.Label("Operation Time - Forward (ms):"),
dcc.Input(id='op_time_forward', type='number', value=default_values["op_time_forward"], min=0.01, step=0.01, style={'width': '100%'}),
html.Label("Operation Time - Backward (ms):"),
dcc.Input(id='op_time_backward', type='number', value=default_values["op_time_backward"], min=0.01, step=0.01, style={'width': '100%'}),
html.Label("Operation Time - Backward D (Data Grad) (ms):"),
dcc.Input(id='op_time_backward_d', type='number', value=default_values["op_time_backward_d"], min=0.01, step=0.01, style={'width': '100%'}),
html.Label("Operation Time - Backward W (Weight Grad) (ms):"),
dcc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01, style={'width': '100%'}),
], style={'padding': 10, 'flex': 1}),
], style={'display': 'flex', 'flexDirection': 'row'}),
html.Div([
html.Button('Generate Schedule', id='generate-button', n_clicks=0, style={'margin': '20px auto', 'display': 'block'}),
]),
html.Div(id='error-message', style={'color': 'red', 'textAlign': 'center', 'marginTop': '10px'}),
dcc.Loading(
id="loading-graph",
type="circle",
children=dcc.Graph(id='pipeline-graph', figure=go.Figure())
)
])
@app.callback(
Output('pipeline-graph', 'figure'),
Output('error-message', 'children'),
Input('generate-button', 'n_clicks'),
State('num_devices', 'value'),
State('num_stages', 'value'),
State('num_batches', 'value'),
State('p2p_latency', 'value'),
State('op_time_forward', 'value'),
State('op_time_backward', 'value'),
State('op_time_backward_d', 'value'),
State('op_time_backward_w', 'value'),
State('strategy', 'value'),
State('split_backward', 'value'),
State('placement_strategy', 'value'),
prevent_initial_call=True
)
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
strategy, split_backward_list, placement_strategy):
error_message = ""
fig = go.Figure()
split_backward = 'True' in split_backward_list
# Basic Validations
if not all([num_devices, num_stages, num_batches, op_time_forward]):
return fig, "Missing required input values."
if split_backward and not all([op_time_backward_d, op_time_backward_w]):
return fig, "Backward D and Backward W times are required when 'Split Backward' is checked."
if not split_backward and not op_time_backward:
return fig, "Backward time is required when 'Split Backward' is unchecked."
if num_stages % num_devices != 0 and placement_strategy != 'dualpipe':
return fig, "Number of Stages must be divisible by Number of Devices for standard/interleave placement."
if placement_strategy == 'dualpipe' and num_stages % 2 != 0:
return fig, "DualPipe requires an even number of stages."
if placement_strategy == 'dualpipe' and num_stages != num_devices:
return fig, "DualPipe requires Number of Stages to be equal to Number of Devices."
if strategy == 'dualpipe' and not split_backward:
return fig, "DualPipe strategy currently requires 'Split Backward' to be checked."
if strategy == 'dualpipe' and placement_strategy != 'dualpipe':
return fig, "DualPipe strategy requires 'DualPipe' placement strategy."
if strategy == 'zb1p' and not split_backward:
return fig, "ZB-1P strategy requires 'Split Backward' to be checked."
try:
op_times = {
"forward": float(op_time_forward),
}
if split_backward:
op_times["backward_D"] = float(op_time_backward_d)
op_times["backward_W"] = float(op_time_backward_w)
# Add combined backward time for compatibility if needed by some visualization or calculation
op_times["backward"] = float(op_time_backward_d) + float(op_time_backward_w)
else:
op_times["backward"] = float(op_time_backward)
config = ScheduleConfig(
num_devices=int(num_devices),
num_stages=int(num_stages),
num_batches=int(num_batches),
p2p_latency=float(p2p_latency),
placement_strategy=placement_strategy,
split_backward=split_backward,
op_times=op_times,
)
schedule_func = STRATEGIES.get(strategy)
if not schedule_func:
raise ValueError(f"Invalid strategy selected: {strategy}")
schedule = schedule_func(config)
schedule.execute() # Calculate start/end times
vis_data = convert_schedule_to_visualization_format(schedule)
fig = create_pipeline_figure(vis_data, show_progress=False) # Disable progress bar in server mode
except AssertionError as e:
error_message = f"Configuration Error: {e}"
fig = go.Figure() # Return empty figure on error
except ValueError as e:
error_message = f"Input Error: {e}"
fig = go.Figure()
except Exception as e:
error_message = f"An unexpected error occurred: {e}"
fig = go.Figure()
return fig, error_message
if __name__ == '__main__':
port = 8050
# Timer(1, open_browser, args=(port,)).start() # Optional: automatically open browser
print(f"Dash server running on http://127.0.0.1:{port}")
app.run_server(debug=True, port=port)