\documentclass{article}
\usepackage{graphicx}
\usepackage{hyperref}
\usepackage{amsmath}
\usepackage{caption}
\usepackage{tgtermes}
\usepackage{float}
\usepackage[a4paper, margin=1in]{geometry}
\usepackage{booktabs}
\usepackage{algorithm}
\usepackage{algorithmicx}
\usepackage{algpseudocode}
\date{}
\begin{document}
{\LARGE \bfseries Parallelize Muon with FSDP2 \par}
\vspace{1em} % spacing below the title
\section*{Motivation}
\begin{figure}[H] \centering \includegraphics[width=0.8\textwidth]{distributed_muon.png} \caption*{Distributed Muon by Moonlight} \end{figure}
While a distributed version of Muon is available, it performs redundant computation: every DP rank runs Newton-Schulz on the same gathered gradient.
\begin{figure}[H] \centering \includegraphics[width=1.0\textwidth]{distributed_muon_execution.png} \caption*{Execution timeline of Distributed Muon} \end{figure}
\begin{itemize}
\item \texttt{C[i]} : Compute Newton-Schulz(G) for the $i$-th gradient
\item \texttt{AG[i]} : AllGather the $i$-th gradient
\item \texttt{G[i]} : Gather the $i$-th gradient
\item \texttt{SC[i]} : Scatter the $i$-th update
\end{itemize}
\clearpage
\section*{Algorithm}
\subsection*{Parallel Muon}
\begin{algorithm}
\caption{Parallel Muon}
\textbf{Require:} DP-partitioned gradient $\mathbf{g}$, DP-partitioned momentum $\mathbf{m}$, DP-partitioned parameter $\mathbf{p}$, momentum coefficient $\mu$, local rank $\mathbf{r}$
\begin{algorithmic}[1]
\State \texttt{// Apply momentum to $\mathbf{g}$ using the local partitioned momentum $\mathbf{m}$}
\State $\mathbf{g'} \gets \text{update\_with\_momentum}(\mathbf{g}, \mathbf{m}, \mu)$
\State \texttt{// Assign $\mathbf{g'}$ to a rank $\mathbf{R}$}
\State $\mathbf{R} \gets \text{schedule}(\mathbf{g'}, \text{dp\_group})$
\State \texttt{// Gather $\mathbf{g'}$ across DP into a full matrix $\mathbf{G}$ on rank $\mathbf{R}$}
\State $\mathbf{G} \gets \text{gather}(\mathbf{g'}, \text{dp\_group}, \text{dst=}\mathbf{R})$
\State \texttt{// Run Newton-Schulz only on rank $\mathbf{R}$}
\If{$\mathbf{r}$ == $\mathbf{R}$} \State $\mathbf{u} \gets 
\text{Newton-Schulz}(\mathbf{G})$ \Else \State $\mathbf{u} \gets \text{None}$ \EndIf
\State \texttt{// Scatter the full matrix $\mathbf{u}$ across DP}
\State $\mathbf{u'} \gets \text{scatter}(\mathbf{u},\text{dp\_group},\text{src=}\mathbf{R})$
\State \texttt{// Apply the DP-partitioned $\mathbf{u'}$ to $\mathbf{p}$}
\State $\mathbf{p'} \gets \text{apply\_update}(\mathbf{p}, \mathbf{u'})$
\State \textbf{return $\mathbf{p'}$}
\end{algorithmic}
\end{algorithm}
We eliminate redundant computation by assigning each parameter to a specific GPU. Without proper scheduling, however, this optimization leads to poor GPU utilization: while the assigned rank computes Newton-Schulz, every other rank sits idle until the scatter communication completes.
\begin{figure}[H] \centering \includegraphics[width=1.0\textwidth]{naive_execution.png} \caption*{Execution timeline of Parallel Muon} \end{figure}
\subsection*{Scheduling Sub-Operations}
We can freely reorder the sub-operations, for two reasons:
\begin{itemize}
\item There are no dependencies between parameters.
\item GPUs can execute computation and communication concurrently.
\end{itemize}
\begin{figure}[H] \centering \includegraphics[width=1.0\textwidth]{pipelined.png} \caption*{Execution timeline of re-scheduled Parallel Muon} \end{figure}
We define the chunk size $C$ as the number of GPUs and schedule the sub-operations in batches of size $C$. This allows each GPU to keep computing even while waiting for collective communication to complete.
\textbf{[Algorithm]} (To be written)
\clearpage
\subsection*{Load Balancing}
If the parameters in a chunk have imbalanced computation loads, idle bubbles may occur. \\
To mitigate this, we apply load balancing based on per-parameter FLOPs.
\vspace{1em}
\textbf{Imbalanced (Round Robin)}
\begin{figure}[H] \centering \includegraphics[width=1.0\textwidth]{imbalance.png} \end{figure}
\textbf{After Load Balancing}
\begin{figure}[H] \centering \includegraphics[width=1.0\textwidth]{balanced.png} \end{figure}
\section*{Implementation}
The full implementation is available in \texttt{optimizer/torch-ext/optimizer/muon.py}. To overlap computation and communication, we use separate compute and communication streams (\texttt{torch.cuda.Stream}) and synchronize sub-operations with \texttt{torch.cuda.Event}. Thanks to \texttt{torch.DTensor} and \texttt{torch.distributed}, the implementation remains straightforward.
\section*{Evaluation}
We evaluated performance on a 10B-parameter model currently in development, achieving 151 TFLOPS per GPU during the optimizer step.
\begin{table}[H]
\centering
\begin{tabular}{@{}lllll@{}}
\toprule
Model Size & TFLOPs for Muon & GPUs & Elapsed time & TFLOPS/GPU \\ \midrule
10B & 847.45 & 4xMI250 (8 devices) & 1.4 s & 151 \\ \bottomrule
\end{tabular}
\end{table}
Based on the breakdown, 7\% of the time is spent updating sharded gradients and parameters, 78\% on GEMM operations, and the remaining 15\% on non-overlapped communication overhead.
\end{document}