When we train large neural networks, we need to keep them healthy. We do not want the tensors in the network—whether weights, activations or gradients—to grow too large or too small. Very small and very large tensors cause a variety of problems, not limited to numerical underflow and overflow. For example, weight matrices that change size during training make it harder to design training algorithms, since the relative size of updates to weights has a significant impact on the speed of learning.
The gold standard for keeping tensors healthy is to normalize them. Normalization is commonplace for activation vectors, where we use techniques like layer norm to put the activations on a good scale before passing them to the next layer. It is also commonplace to normalize gradient updates, where we can interpret fast training algorithms like the Muon optimizer as spectrally normalizing the updates. Normalization provides us with certainty about the sizes of tensors—without needing to check Wandb!—and when training large neural networks with many interacting components, having certainty about the network internals is valuable.
Normalization is less commonly applied to weight matrices, although it is not unheard of. For example, the EDM2 diffusion model codebase uses weight constraints and the authors report benefits in their paper. Various other techniques have been proposed but are not common practice in modern large-scale training. For some examples, see Salimans et al, 2016, Miyato et al, 2018 and our paper Liu et al, 2021. Normalizing the weight matrices might be a good idea for a few reasons. Weight constraints make understanding the relative size of optimization updates easier. They remove the problem of weight norms exploding. They allow us to focus hyperparameter tuning effort on tensors whose size matters most. They can force matrices to have a small condition number, making their behaviour more predictable. And relatedly, weight constraints facilitate Lipschitz guarantees for robustness to perturbations.
This post covers one appealing way to constrain the weight matrices of a neural network—by keeping the tensors constrained to submanifolds at each layer. This opens the door to re-thinking optimization, as we can co-design optimization algorithms with these manifold constraints. As an example, we propose a manifold version of the Muon optimizer whose weights are constrained to the Stiefel manifold: the manifold of matrices with unit condition number. (This algorithm builds on work from Jianlin Su and Franz Louis Cesista, as discussed further below.) We conclude the post by defining the idea of a modular manifold, which is a composable manifold that attempts to make it easier to scale up and train large networks.
Our goal in writing this post is to provide an introduction to a research area that we are excited about, and highlight many directions for future work. We would love to see more work from the community on the topics mentioned at the end of the post!
The shape of a manifold optimizer
This section works through the simplest example of learning on a manifold: a vector parameter constrained to a hypersphere in $\mathbb{R}^d$. The vector parameter is trained to minimize a loss function defined over the full space $\mathbb{R}^d$. This setup might be useful for, say, individual embedding vectors in a transformer model. This section will be a good warmup for the following section on manifold Muon that considers matrix parameters.
We will not be too formal about the definition of a manifold here: it is enough to understand that a manifold is a curved surface that looks flat when you zoom in close enough. The locally flat approximation at a point on the manifold is called the tangent space to the manifold, as visualized in the figure below:
The sphere in three dimensions—or the hypersphere in higher dimensions—is a manifold. The locally flat approximation at a point on the manifold is called the tangent space to the manifold and is visualized as the red plane in the figure.
We can characterize the hypersphere in $d$ dimensions as the set of points $w \in \mathbb{R}^d$ of unit Euclidean norm. And the tangent space at a point $w$ on the hypersphere is the set of all vectors $a \in \mathbb{R}^d$ that are orthogonal to $w$.
To keep the weights constrained to the manifold, we could use a non-manifold optimizer and just project the weights back to the manifold after each step. Instead, we are interested in designing methods that take steps in the tangent space. The reason is that we would like to be able to equate the learning rate of our optimizer with the actual length of the optimization step. But if the optimization steps are pointing significantly off manifold and then being projected back, this nice property does not hold. Similar motivation is given in Section 2.3 of the EDM2 paper.
Before we can design a training algorithm for this manifold, something important we need to decide on is how to measure distance in the tangent space. (For a manifold to be “Riemannian”, the distance measure must be induced by an inner product. The Euclidean ($\ell_2$) norm is induced by an inner product, but the Manhattan ($\ell_1$) distance is not.) A common choice is the Euclidean distance, but we could also choose to measure distance in other ways, as visualized in the figure below. In the next section, we will talk about choosing a distance measure based on the functionality of the module.
Inscribing unit balls in the tangent space for different distance measures. The $\ell_2$ (Euclidean) unit ball is a circle while the $\ell_1$ (Manhattan) unit ball is a diamond.
Crucially, the choice of distance measure changes the direction of the best optimization step. If the distance measure is non-Euclidean, then for a fixed length step, we may be able to move further in the direction of the gradient by not following the gradient direction exactly! (By gradient, we mean the partial derivative of the loss with respect to the weights. Mathematicians reserve the term gradient for something else in Riemannian geometry.) This concept is visualized in the figure below.
How geometry influences the direction of the best optimization step. The pink arrow represents the raw gradient—meaning the partial derivative of the loss with respect to the weights. The yellow diamond denotes the $\ell_1$ unit ball. The green arrow is the unit vector pointing furthest in the direction of the gradient. Notice how the green arrow is not parallel to the pink arrow. (In practice, the pink arrow need not lie in the tangent space, although the green arrow does by construction.)
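To make this concrete, here is a minimal numerical sketch of how the choice of norm changes the steepest direction. Each update rule below is the standard maximizer of $g^\top a$ over the corresponding unit ball; the vector $g$ is an arbitrary illustrative gradient.

```python
import numpy as np

# How the norm changes the "best" unit step maximizing alignment with g.
g = np.array([1.0, 0.3])  # an arbitrary illustrative gradient

# Euclidean unit ball: the best step is parallel to g.
a_l2 = g / np.linalg.norm(g)

# Infinity-norm unit ball (a cube): the best step saturates every coordinate.
a_linf = np.sign(g)

# l1 unit ball (a diamond): the whole budget goes to the largest coordinate.
a_l1 = np.zeros_like(g)
i = np.abs(g).argmax()
a_l1[i] = np.sign(g[i])

print(a_l2, a_linf, a_l1)  # three different directions from the same gradient
```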
To see how this works out in math, we can formulate the optimal update direction given a manifold constraint and a distance measure as itself solving a constrained optimization problem. We will demonstrate this for the case of the hypersphere equipped with the Euclidean norm. Letting $g$ denote the gradient, $w$ the current point on the hypersphere, $a$ the update direction and $\eta$ the learning rate, we need to solve:

$$a_\mathrm{opt} = \arg\min_{a \in \mathbb{R}^d} \; g^\top a \quad \text{subject to} \quad a^\top w = 0 \;\;\text{and}\;\; \|a\|_2 = \eta. \qquad (\star)$$
Mapping back to the visual language of the figures above, this formula says that the green arrow (optimal value of $a$) must belong to the red tangent hyperplane ($a^\top w = 0$) and must also lie on a yellow circle of radius $\eta$ ($\|a\|_2 = \eta$). To solve $(\star)$, we can apply the method of Lagrange multipliers. The relevant Lagrangian function is given by:

$$\mathcal{L}(a, \lambda, \mu) = g^\top a + \lambda \cdot a^\top w + \mu \cdot (a^\top a - \eta^2),$$
where $\lambda$ and $\mu$ are Lagrange multipliers. Setting the derivative of the Lagrangian with respect to $a$ to zero and applying the constraints to solve for $\lambda$ and $\mu$, the optimal update $a_\mathrm{opt}$ ends up being given by the following formula:

$$a_\mathrm{opt} = -\eta \cdot \frac{g - (g^\top w)\, w}{\|g - (g^\top w)\, w\|_2}.$$
In words, the optimal update is given by subtracting out the radial component from the gradient, normalizing and multiplying by the learning rate. Since this update lies in the tangent space, actually a very small correction is needed to stay on the manifold. (For a learning rate $\eta$, the effect of the retraction map is $\mathcal{O}(\eta^2)$ small, so the learning rate almost equals the length of the step.) The correction is known as a “retraction map” and is visualized in the figure below:
Visualizing the retraction map. The green arrow is the update taken in the tangent space. Since for large step sizes the tangent space starts to diverge from the manifold, we need to project the updated weights back to the manifold using the retraction map—illustrated by the purple arrow.
We can solve for the retraction map by applying Pythagoras’ theorem to the figure above. For a unit hypersphere and a step of length $\eta$, the hypotenuse has length $\sqrt{1+\eta^2}$ and therefore the retraction map for the hypersphere equipped with the Euclidean norm is simply given by dividing the updated weights through by $\sqrt{1+\eta^2}$. Putting everything together, the full manifold optimization algorithm is then given by:

$$w \gets \frac{w + a_\mathrm{opt}}{\sqrt{1+\eta^2}} = \frac{1}{\sqrt{1+\eta^2}} \left( w - \eta \cdot \frac{g - (g^\top w)\, w}{\|g - (g^\top w)\, w\|_2} \right).$$
As an exercise for the reader: try calculating the Euclidean norm of the updated weight vector and check that the updated weight vector indeed lies on the hypersphere.
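For readers who prefer code, here is a minimal NumPy sketch of this hypersphere optimizer step, assuming the weight $w$ already lies on the hypersphere. The small epsilon guarding the normalization is an implementation detail, not part of the derivation.

```python
import numpy as np

def hypersphere_step(w, g, eta):
    """One step of hyperspherical descent, following the derivation above."""
    a = g - (g @ w) * w                    # subtract out the radial component
    a = a / (np.linalg.norm(a) + 1e-12)    # normalize to unit length
    w_new = w - eta * a                    # step of length eta in the tangent space
    return w_new / np.sqrt(1 + eta ** 2)   # retraction: exact for the hypersphere

# Example: the updated weight stays on the unit hypersphere.
w = np.random.randn(5); w /= np.linalg.norm(w)
g = np.random.randn(5)
print(np.linalg.norm(hypersphere_step(w, g, eta=0.1)))  # 1.0 (up to rounding)
```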
To summarize this section, a first-order manifold optimizer has three steps:
1. Find the tangent vector of unit length that goes furthest in the gradient direction.
2. Multiply this direction by the learning rate and subtract from the weights.
3. Retract the updated weights back to the manifold.
There are two decisions to make in applying this procedure: what manifold constraint we should use and how we should measure length. By making different decisions, we can generate different optimization algorithms as shown in the following table.
| Manifold | Norm | Optimizer |
| --- | --- | --- |
| Euclidean $\mathbb{R}^n$ | Euclidean norm | vanilla gradient descent |
| Euclidean $\mathbb{R}^n$ | infinity norm | sign gradient descent |
| hypersphere $S^n$ | Euclidean norm | hyperspherical descent |
| matrix space $\mathbb{R}^{m\times n}$ | spectral norm | Muon |
| Stiefel manifold $\subset\mathbb{R}^{m\times n}$ | spectral norm | manifold Muon |
We will derive the final algorithm in the table, manifold Muon, in the next section. To design a manifold constraint and a distance function for a matrix parameter, we shall think carefully about the role that a weight matrix plays inside a neural network.
Manifold Muon
A typical weight matrix $W$ in a transformer is a “vector-multiplier”, meaning that it transforms an input vector $x$ into an output vector $y = Wx$. We will design a manifold constraint and a distance function so that the matrix acts in a good way on input vectors: the matrix should not produce excessively small or large outputs, and updates to the matrix should not cause the output vector to change too much or too little.
A good way to think about how a matrix acts on vectors is through the singular value decomposition, illustrated in the figure below. The SVD decomposes a matrix in a way that tells us how the matrix stretches input vectors along different axes.
The singular value decomposition. A matrix $M\in\mathbb{R}^{m\times n}$ of rank $k$ can always be decomposed as $M = U \Sigma V^\top$, where $U\in\mathbb{R}^{m\times k}$ and $V\in\mathbb{R}^{n\times k}$ have orthonormal columns and $\Sigma\in\mathbb{R}^{k\times k}$ is a diagonal matrix with only positive entries. The entries of $\Sigma$ are called the singular values of $M$. The singular values measure the stretching effect that the matrix has on vectors that align with the corresponding columns of $U$ and $V$.
We would like the matrix to have a stretching effect close to one, so we will choose a matrix manifold where all the singular values are exactly one. This matrix manifold is known formally as the Stiefel manifold. We can assume without loss of generality that we are dealing with a tall matrix ($m \geq n$), and then the Stiefel manifold can be equivalently defined as the following set:

$$\mathrm{Stiefel}(m, n) := \left\{ W \in \mathbb{R}^{m \times n} : W^\top W = I_n \right\}.$$
Furthermore, one may show that a matrix $A \in \mathbb{R}^{m \times n}$ lies tangent to the Stiefel manifold at matrix $W$ if and only if:

$$A^\top W + W^\top A = 0.$$

(Notice that the Stiefel constraint $W^\top W = I_n$ directly generalizes the hyperspherical constraint $w^\top w = 1$ from the previous section. Similarly, the tangent space condition generalizes the hyperspherical one that $a^\top w = 0$.)
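As a quick numerical sanity check of these two conditions, here is a small sketch. The QR factorization is one standard way to obtain a point on the Stiefel manifold, and the formula $A - W\,\mathrm{sym}(W^\top A)$ is the usual Euclidean projection onto the tangent space.

```python
import numpy as np

m, n = 8, 3
W, _ = np.linalg.qr(np.random.randn(m, n))   # a point on Stiefel(m, n): W.T @ W = I_n
A = np.random.randn(m, n)                    # an arbitrary matrix, generally not tangent

sym = lambda M: (M + M.T) / 2
A_tan = A - W @ sym(W.T @ A)                 # project A onto the tangent space at W

print(np.allclose(W.T @ W, np.eye(n)))               # True: on the manifold
print(np.allclose(A_tan.T @ W + W.T @ A_tan, 0.0))   # True: tangent condition holds
```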
To design a manifold optimizer for the Stiefel manifold, all that remains is to choose a distance function. To limit the maximum stretching effect the weight update can have on an input vector, we will choose the spectral norm, which measures the largest singular value of a matrix. Although the spectral norm only limits the maximum effect of the update, the optimizer we derive will saturate this bound, which turns out to also prevent the minimum effect of the update from being too small. There are some exceptions to this statement, such as when a weight matrix has a fan-out less than its fan-in, in which case we cannot avoid the matrix and its updates having a null space and mapping some inputs to zero.
The idea of doing gradient descent under a spectral norm constraint is what led to the Muon optimizer and, when combined with the Stiefel manifold constraint, we obtain a problem that we shall call manifold Muon:

$$A_\mathrm{opt} = \arg\min_{A \in \mathbb{R}^{m \times n}} \; \operatorname{trace}(G^\top A) \quad \text{subject to} \quad \|A\|_\mathrm{spectral} \leq \eta \;\;\text{and}\;\; A^\top W + W^\top A = 0, \qquad (\dagger)$$

where $G$ denotes the gradient of the loss with respect to the weight matrix $W$.
The manifold Muon problem $(\dagger)$ directly generalizes problem $(\star)$ from the previous section. Solving $(\dagger)$ is harder than solving $(\star)$, and here we will present a numerical solution inspired by work done by Jianlin Su and Franz Louis Cesista. (I figured out how to solve manifold Muon in the square case late last year, but I was unable to solve the full rectangular case and thus posed it as an open problem on the Modula docs. Jianlin Su solved the problem this summer by taking a Lagrangian approach and working out a fixed point iteration on the optimality condition. After seeing an early version of Jianlin’s work, which did not quite work yet, and related work by Franz Louis Cesista, I was able to work out the dual ascent algorithm presented here.)
Our key insight is that $(\dagger)$ is a convex optimization problem that may be solved via a standard method known as dual ascent. Here we will just sketch the main idea, but you can find a more detailed derivation on this page.
Similar to Jianlin’s approach, we introduce a matrix of Lagrange multipliers $\Lambda\in\mathbb{R}^{n\times n}$. We then apply a series of transformations to convert the problem $(\dagger)$ from a constrained minimization problem to an unconstrained maximization problem:

$$\min_{\substack{\|A\|_\mathrm{spectral} \,\leq\, \eta \\ A^\top W + W^\top A \,=\, 0}} \operatorname{trace}(G^\top A) = \min_{\|A\|_\mathrm{spectral} \leq \eta} \; \max_{\Lambda} \; \Big[ \operatorname{trace}(G^\top A) + \operatorname{trace}\big( (\Lambda + \Lambda^\top)(A^\top W + W^\top A) \big) \Big] \tag{1}$$

$$= \min_{\|A\|_\mathrm{spectral} \leq \eta} \; \max_{\Lambda} \; \operatorname{trace}\Big[ \big( G + 2W(\Lambda + \Lambda^\top) \big)^\top A \Big] \tag{2}$$

$$= \max_{\Lambda} \; \min_{\|A\|_\mathrm{spectral} \leq \eta} \; \operatorname{trace}\Big[ \big( G + 2W(\Lambda + \Lambda^\top) \big)^\top A \Big] \tag{3}$$

$$= \max_{\Lambda} \; -\eta \cdot \big\| G + 2W(\Lambda + \Lambda^\top) \big\|_\mathrm{nuclear}. \tag{4}$$
Equation (1) reformulates the problem as a saddle point problem: the maximization over $\Lambda$ will send the objective to infinity whenever the tangent space condition is violated. Equation (2) follows by applying properties of the trace and equation (3) follows from Sion’s minimax theorem. The inner minimization in equation (3) is solved by setting $A_\mathrm{opt}(\Lambda) = - \eta \times \operatorname{msign}(G + 2W(\Lambda+\Lambda^\top))$ where $\operatorname{msign}$ is the matrix sign function. The matrix sign function snaps the singular values of a matrix to one. It may be computed efficiently on GPUs via Newton-Schulz iteration or the recent Polar Express algorithm. And we obtain equation (4) by substituting this expression for $A_\mathrm{opt}(\Lambda)$ into equation (3). Equation (4) is known as the “dual problem” to $(\dagger)$ and we can solve it by gradient ascent. After some work, the gradient of the dual function is given by:

$$H(\Lambda) := \nabla_\Lambda \left[ -\eta \cdot \big\| G + 2W(\Lambda + \Lambda^\top) \big\|_\mathrm{nuclear} \right] = 2 \left( A_\mathrm{opt}(\Lambda)^\top W + W^\top A_\mathrm{opt}(\Lambda) \right),$$
where the nuclear norm $\|\cdot\|_\mathrm{nuclear}$ measures the sum of the singular values of a matrix. Notice that $H(\Lambda)$ measures the violation of the tangent space condition by $A_\mathrm{opt}(\Lambda)$, so dual ascent drives the update toward the tangent space.
Finally, we can write down the manifold Muon algorithm. (Note that this algorithm is closely related to Jianlin Su’s solution. Where we run dual ascent, Jianlin’s solution amounts to solving the optimality condition $H(\Lambda)=0$ for the maximizer of the dual function via a fixed point iteration.)
1. Run gradient ascent on the dual variable $\Lambda \gets \Lambda + \alpha \times H(\Lambda)$ to solve for $\Lambda_\mathrm{opt}$.
2. Compute the update $A_\mathrm{opt} = - \eta \times \operatorname{msign}(G + 2W(\Lambda_{\mathrm{opt}}+\Lambda_\mathrm{opt}^\top))$.
3. Apply the update to the weights $W \gets W + A_\mathrm{opt}$.
4. Retract the weights back to the manifold $W \gets \operatorname{msign}(W)$.
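Here is a minimal NumPy sketch of these four steps, assuming a tall, full-rank weight matrix ($m \geq n$). The `msign` helper uses the simple cubic Newton-Schulz iteration rather than the faster quintic variants used in Muon or the Polar Express algorithm, and the step sizes and iteration counts are illustrative rather than tuned.

```python
import numpy as np

def msign(M, steps=30):
    """Approximate the matrix sign (polar factor) of M via cubic Newton-Schulz."""
    X = M / (np.linalg.norm(M) + 1e-12)   # Frobenius scaling puts singular values in (0, 1]
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.T @ X   # snaps the singular values toward one
    return X

def manifold_muon_update(W, G, eta, alpha=0.1, dual_steps=50):
    """One manifold Muon step for a tall matrix W with W.T @ W = I."""
    n = W.shape[1]
    Lam = np.zeros((n, n))
    for _ in range(dual_steps):                        # step 1: dual ascent on Lambda
        A = -eta * msign(G + 2 * W @ (Lam + Lam.T))    # inner minimizer A_opt(Lambda)
        H = 2 * (A.T @ W + W.T @ A)                    # dual gradient = tangency violation
        Lam = Lam + alpha * H
    A_opt = -eta * msign(G + 2 * W @ (Lam + Lam.T))    # step 2: compute the update
    W = W + A_opt                                      # step 3: apply the update
    return msign(W)                                    # step 4: retract to the manifold
```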
We ran a very small experiment to sanity check the algorithm and provide a minimal implementation for students or researchers to play with. Each training run finishes in less than a minute. The code is here; see the figure below for the setup and results.
Training a small MLP for 3 epochs on the CIFAR-10 dataset. The different lightly shaded blue curves show different weight decay settings for AdamW. Results were averaged over 3 random seeds. The manifold Muon optimizer attained higher train and test accuracy than AdamW. The third plot shows the final singular value distribution of the first weight matrix for the best performing learning rate: the singular values after training with manifold Muon are all close to 1. Manifold Muon increased the wall clock time per step compared to AdamW, although this could be improved by running fewer steps of dual ascent or adding momentum to the algorithm and running dual ascent online. Depending on other systems bottlenecks, the overhead may not be an issue.
Modular manifolds
So far in this post, we have discussed manifold constraints for individual parameter tensors and co-designed optimization logic for these constraints. A question we have not answered is: what happens when we combine layers to build networks? Can we think about individual layers in isolation—or do we need to be careful about interactions between layers and modify the optimization logic in response? The goal of this section is to point out that there is a way to extend the reasoning we introduced in the previous two sections to the case of whole networks, and we call this the theory of modular manifolds. (This theory builds on research I did with my friend Tim Large, my postdoc advisor Phillip Isola, my PhD advisor Yisong Yue and many other amazing collaborators. At the end of the section, we provide some links to learn more.)
The idea of modular manifolds is to build an abstraction that tells us how to budget learning rates across layers. The actual optimization logic in each layer ends up being the same as what we already worked out, except that the learning rate for a layer is modified depending on where the layer appears in the network. The abstraction rests upon a key observation made in our paper on the modular norm, that budgeting learning rates—both across layers and when scaling up individual layers—is intimately tied to understanding the Lipschitz sensitivity of the network output with respect to the weights. The abstraction tracks this sensitivity as we build the network, and manifold constraints help us get a much tighter understanding of this sensitivity.
The starting point for the abstraction is to think of any neural network module—from a layer to a whole transformer—as a mathematical object with three attributes:
1. A forward function $f:\mathcal{W} \times \mathcal{X} \to \mathcal{Y}$ that maps from a parameter space $\mathcal{W} = \mathbb{R}^d$ and an input space $\mathcal{X}$ to an output space $\mathcal{Y}$.
2. A submanifold of the weight space $\mathcal{M}\subset\mathcal{W}$ that the weights are constrained to.
3. A norm $\|\cdot\| : \mathcal{W} \to \mathbb{R}$ that acts as a measuring stick on weight space.
For example, a linear module equipped with the spectral norm and constrained to the Stiefel manifold, for which we have already worked out an optimizer, would be written:

$$\mathsf{StiefelLinear} := \Big( f(W, x) := Wx, \quad \mathcal{M} := \mathrm{Stiefel}(m, n), \quad \|\cdot\| := \|\cdot\|_\mathrm{spectral} \Big).$$
Provided that an input $x$ to the $\mathsf{StiefelLinear}$ module has unit $\ell_2$ norm, then $\mathsf{StiefelLinear}$ is Lipschitz with respect to its weights in the module’s assigned norm with Lipschitz constant one:

$$\|f(W + \Delta W, x) - f(W, x)\|_2 = \|\Delta W\, x\|_2 \leq \|\Delta W\|_\mathrm{spectral} \cdot \|x\|_2 = \|\Delta W\|_\mathrm{spectral}.$$

(This argument can be extended to the RMS norm on the input and the RMS–RMS operator norm on the weights.)
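A quick numerical check of this bound, for an arbitrary weight perturbation and a unit-norm input:

```python
import numpy as np

rng = np.random.default_rng(0)
dW = rng.standard_normal((8, 3))                 # an arbitrary weight perturbation
x = rng.standard_normal(3)
x /= np.linalg.norm(x)                           # unit l2 norm input

output_change = np.linalg.norm(dW @ x)
spectral_norm = np.linalg.norm(dW, ord=2)        # largest singular value
print(output_change <= spectral_norm + 1e-12)    # True: the Lipschitz bound holds
```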
This type of Lipschitz statement helps us understand how to scale weight updates to this module since it gives us a bound on how much the output can change when we perturb the weights. But when we compose two modules, can we automatically compile a Lipschitz statement on the joint weight space of the new module? The answer turns out to be yes, if we follow special rules for building the new module:
1. The new forward function $f_3$ is given by composing the two existing forward functions $f_1$ and $f_2$: $$f_3((w_1, w_2), x) := f_2(w_2, f_1(w_1, x)).$$
2. The new manifold constraint $\mathcal{M}_3$ is just the Cartesian product (see the figure below for a fun example) of the two existing manifolds $\mathcal{M}_1$ and $\mathcal{M}_2$: $$\mathcal{M}_3 = \mathcal{M}_1 \times \mathcal{M}_2.$$
3. The new norm function is the max of the two existing norm functions weighted by special scalar coefficients $s_1$ and $s_2$. Letting $\|\cdot\|_1$ denote the first module’s norm and $\|\cdot\|_2$ denote the second module’s norm, the new norm $\|\cdot\|_3$ is given by: $$\|(w_1, w_2)\|_3 := \max(s_1\cdot \|w_1\|_1, s_2\cdot \|w_2\|_2).$$
When we use this composite norm to derive optimizers—following the same recipe we used in the first two sections of this post—we end up deriving separate optimizers for each layer, but the scalar coefficients $s_i$ budget the learning rates across layers.
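To see why the coefficients act as a learning rate budget, note that a steepest descent step of length $\eta$ in the composite max norm lets each layer saturate its own constraint, so layer $i$ receives an effective step of size $\eta / s_i$ in its own norm. A small illustrative sketch (the function name here is hypothetical, not from a real library):

```python
# Under the constraint max(s1 * ||a1||_1, s2 * ||a2||_2) <= eta, the steepest
# descent step saturates each component, giving layer i a step of size eta / s_i.

def budgeted_learning_rates(eta, coefficients):
    """Per-layer step lengths implied by the weighted max norm (illustrative)."""
    return [eta / s for s in coefficients]

print(budgeted_learning_rates(0.1, [1.0, 2.0, 4.0]))  # [0.1, 0.05, 0.025]
```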
We give much more detail on this construction, including extending it to other ways of combining modules, in our paper on the modular norm—although the paper does not cover manifold optimization. You can also check out our paper on modular duality for more on building optimizers in the modular norm. The Modula project builds toward a programmatic implementation of this construction.
The Cartesian product is a simple way to glue together two manifolds. For example, the product of a line and a disk is a cylinder. We get one copy of the disk at every point on the line.
Directions for future work
We are excited about any research that tries to make neural network training as principled and automatic as the forward pass. The ideas in this post benefited greatly from interactions with external researchers like Jianlin Su and Franz Louis Cesista. We would love to see more work on these topics from the community.
Some possible directions for future work are:
- Modularity. What manifolds should attention heads live on? Should embeddings be constrained differently than unembeddings? We can mix-and-match constraints in different parts of the network, or leave some tensors unconstrained.
- Numerics. Manifold constraints also place constraints on the range of values that individual weight entries can take. Does this impact numerics, or make low-precision training easier?
- Convex optimization. The manifold Muon algorithm involves running dual ascent. Can we apply more sophisticated convex optimization techniques to solve the dual problem faster or more reliably?
- Convergence analysis. How fast do these algorithms converge? Does good conditioning of the weight matrices benefit convergence? Is there more that we can say theoretically?
- Regularization. Manifold constraints implicitly regularize the model. Could we design constraints or tune their radii to improve generalization?
- Architecture-optimizer co-design. While hard manifold constraints may not ultimately be the right way to constrain weight matrices, they exemplify the idea of tightly co-designing optimization algorithms with architectural components. Are there more opportunities here?
- Non-Riemannian geometry. Most work on manifold optimization works in a Riemannian world where distances are induced by inner products and norm balls are ellipsoids. But neural networks are different: matrices act as operators, and operator norms like the spectral norm do not emerge from inner products. This implies, for example, that norm balls can have sharp corners and there is no unique gradient flow. Is there more to be discovered in this non-Riemannian world?
- Practical implementation. Applying these techniques at scale requires efficient manifold operations on GPUs. The recent Polar Express paper shows promise for fast matrix sign computation. What other algorithmic innovations do we need?
Further reading
Citation
Please cite this work as:
Jeremy Bernstein, "Modular Manifolds", Thinking Machines Lab: Connectionism, Sep 2025.
Or use the BibTeX citation:
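A minimal entry consistent with the citation above (the entry key is illustrative):

```bibtex
@article{bernstein2025modular,
  author  = {Jeremy Bernstein},
  title   = {Modular Manifolds},
  journal = {Thinking Machines Lab: Connectionism},
  year    = {2025},
  month   = {Sep},
}
```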