Skip to content
125 changes: 103 additions & 22 deletions desc/optimize/_constraint_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,64 @@ def __getattr__(self, name):
return getattr(self._objective, name)


@functools.partial(jit, static_argnames=["op"])
def _proximal_adjoint_correction_pure(
    constraint,
    xf,
    constants,
    r_eq_full,
    eq_feasible_tangents,
    dxdc,
    op,
):
    """Evaluate the adjoint form of the proximal gradient correction.

    Rather than building the full proximal tangent matrix, a single
    transposed linear solve is performed and the result is pulled back
    through the constraint with one vector-Jacobian product.

    Parameters
    ----------
    constraint : ObjectiveFunction
        Inner equilibrium constraint F. Its ``jvp_<op>`` and ``vjp_<op>``
        methods are looked up by name and called.
    xf : ndarray
        Full equilibrium state vector at the projected point.
    constants : list
        Constants forwarded unchanged to the constraint's jvp/vjp methods.
    r_eq_full : ndarray, shape (dim_x_eq_full,)
        Cotangent on the full equilibrium state coming from the outer
        objective, r_eq_full = (dG/dx_eq_full)^T @ G.
    eq_feasible_tangents : ndarray, shape (dim_x_eq_full, dim_x_eq_red)
        Tangent map from reduced equilibrium variables to the full
        equilibrium state.
    dxdc : ndarray, shape (dim_x_eq_full, dim_c_eq)
        Tangent map from the optimizer-facing equilibrium variables c_eq
        (Rb_lmn, Zb_lmn, profiles, etc.) to the full equilibrium state.
    op : str
        One of {"scaled", "scaled_error", "unscaled"}. Static under jit;
        selects which jvp/vjp flavor of the constraint is used.

    Returns
    -------
    correction_eq_c : ndarray, shape (dim_c_eq,)
        The proximal correction term
        (dF/dc)^T (dF/dx @ dx_tangents)^(-T) (dx_tangents)^T r_eq_full.
    """
    jvp_fun = getattr(constraint, f"jvp_{op}")
    vjp_fun = getattr(constraint, f"vjp_{op}")

    # Pull the full-state cotangent back onto the reduced variables.
    reduced_cotangent = eq_feasible_tangents.T @ r_eq_full

    # dF/dx @ dx_tangents. The jvp is fed the tangents row-wise and its output
    # is transposed back (presumably it returns one row per tangent -- confirm
    # against the constraint's jvp implementation).
    reduced_jac = jvp_fun(eq_feasible_tangents.T, xf, constants).T

    # Regularized pseudo-inverse via SVD: shift every singular value by the
    # smallest one, then zero the inverse of anything below the relative cutoff.
    left, sing, right_t = jnp.linalg.svd(reduced_jac, full_matrices=False)
    rel_tol = jnp.finfo(reduced_jac.dtype).eps * max(reduced_jac.shape)
    sing = sing + sing[-1]
    sing_inv = jnp.where(sing < rel_tol * sing[0], 0.0, 1.0 / sing)

    # Adjoint multiplier: reduced_jac^{-T} @ reduced_cotangent
    #                   = U diag(sing_inv) V^T @ reduced_cotangent.
    multiplier = left @ (sing_inv * (right_t @ reduced_cotangent))

    # (dF/dx_full)^T @ multiplier, then map into optimizer-facing coordinates:
    # (dF/dc)^T multiplier = dxdc^T (dF/dx_full)^T multiplier.
    full_state_pullback = vjp_fun(multiplier, xf, constants)
    return dxdc.T @ full_state_pullback


class ProximalProjection(ObjectiveFunction):
"""Remove equilibrium constraint by projecting onto constraint at each step.

Expand Down Expand Up @@ -1040,7 +1098,9 @@ def compute_unscaled(self, x, constants=None):
return self._objective.compute_unscaled(xopt, constants[0])

def grad(self, x, constants=None):
"""Compute gradient of self.compute_scalar.
"""Compute gradient of self.compute_scalar using an adjoint proximal solve.

This avoids forming the full proximal tangent matrix.

Parameters
----------
Expand All @@ -1052,30 +1112,51 @@ def grad(self, x, constants=None):
Returns
-------
g : ndarray
gradient vector.

Gradient vector.
"""
# We are looking for the gradient of L = 0.5 * G.T @ G
# Then, the gradient is ∇L = G.T @ J_of_G
# where J_of_G is the Jacobian of G with respect to the optimization variables
# We explained getting J_of_G in the _jvp method. It is basically,
# J_of_G = ∇G @ [dc_tangents - (∇F @ dx_tangents) ^ -1 @ (∇F @ dc_tangents)]
# where ∇G is the Jacobian of G with respect to full state vector
# and ∇F is the Jacobian of F with respect to full state vector. Then,
# ∇L = G.T @ ∇G @ [dc_tangents - (∇F @ dx_tangents) ^ -1 @ (∇F @ dc_tangents)]
# We get the part in [] using the _get_tangent method.
v = jnp.eye(x.shape[0])
constants = setdefault(constants, [None, None])

# Project current optimizer variables onto force balance.
xg, xf = self._update_equilibrium(x, store=True)
jvpfun = lambda u: self._get_tangent(u, xf, constants, op="scaled_error")
tangents = batched_vectorize(
jvpfun,
signature="(n)->(k)",
chunk_size=self._constraint._jac_chunk_size,
)(v)
g = self._objective.compute_scaled_error(xg, constants[0])
g_vjp = self._objective.vjp_scaled_error(g, xg, constants[0])
return tangents @ g_vjp

# Outer residual G and outer pullback r = (dG/dx_full)^T @ G.
G = jnp.atleast_1d(self._objective.compute_scaled_error(xg, constants[0]))
r_full = self._objective.vjp_scaled_error(G, xg, constants[0])

# Split full-state cotangent by thing.
r_parts = jnp.split(r_full, np.cumsum(self._dimx_per_thing))
r_eq_full = r_parts[self._eq_idx]

# Direct term: dc_tangents^T @ r.
# For non-eq things c == x, so this is just the corresponding slice of r_full.
# For eq-facing optimizer variables c_eq, the direct map is dxdc.
grad_parts = []
for i, ri in enumerate(r_parts):
if i == self._eq_idx:
grad_parts.append(self._dxdc.T @ ri)
else:
grad_parts.append(ri)
grad_direct = jnp.concatenate(grad_parts)

# Proximal correction term:
# (dF/dc)^T (dF/dx @ dx_tangents)^(-T) dx_tangents^T r_eq_full
correction_eq_c = _proximal_adjoint_correction_pure(
self._constraint,
xf,
constants[1],
r_eq_full,
self._eq_solve_objective._feasible_tangents,
self._dxdc,
"scaled_error",
)

correction_parts = [
jnp.zeros(dim, dtype=grad_direct.dtype) for dim in self._dimc_per_thing
]
correction_parts[self._eq_idx] = correction_eq_c
grad_correction = jnp.concatenate(correction_parts)

return grad_direct - grad_correction

def hess(self, x, constants=None):
"""Compute Hessian of self.compute_scalar.
Expand Down