diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 1041b2a1fb..d8c64410d8 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -549,6 +549,64 @@ def __getattr__(self, name): return getattr(self._objective, name) +@functools.partial(jit, static_argnames=["op"]) +def _proximal_adjoint_correction_pure( + constraint, + xf, + constants, + r_eq_full, + eq_feasible_tangents, + dxdc, + op, +): + """Compute proximal adjoint correction without forming full tangents. + + Parameters + ---------- + constraint : ObjectiveFunction + Inner equilibrium constraint F. + xf : ndarray + Full equilibrium state vector at the projected point. + constants : list + Constraint constants. + r_eq_full : ndarray, shape (dim_x_eq_full,) + Cotangent on the equilibrium full state coming from outer objective: + r_eq_full = dG/dx_eq_full^T @ G. + eq_feasible_tangents : ndarray, shape (dim_x_eq_full, dim_x_eq_red) + Tangent map from reduced equilibrium variables to full equilibrium state. + dxdc : ndarray, shape (dim_x_eq_full, dim_c_eq) + Tangent map from optimizer-facing equilibrium variables c_eq + (Rb_lmn, Zb_lmn, profiles, etc.) to full equilibrium state. + op : str + One of {"scaled", "scaled_error", "unscaled"}. + + Returns + ------- + correction_eq_c : ndarray, shape (dim_c_eq,) + The proximal correction term + (dF/dc)^T (dF/dx @ dx_tangents)^(-T) (dx_tangents)^T r_eq_full. + """ + A = getattr(constraint, "jvp_" + op)(eq_feasible_tangents.T, xf, constants).T + + rhs = eq_feasible_tangents.T @ r_eq_full + + # Solve A^T lambda = rhs with same SVD regularization style as proximal_jvp_f_pure + cutoff = jnp.finfo(A.dtype).eps * max(A.shape) + u, s, vt = jnp.linalg.svd(A, full_matrices=False) + s = s + s[-1] + sinv = jnp.where(s < cutoff * s[0], 0.0, 1.0 / s) + + # lambda = A^{-T} rhs = U diag(sinv) V^T rhs + lam = u @ (sinv * (vt @ rhs)) + + # w_full = (dF/dx_full)^T lambda + w_full = getattr(constraint, "vjp_" + op)(lam, xf, constants) + + # correction in optimizer-facing equilibrium coordinates: + # (dF/dc)^T lambda = dxdc^T (dF/dx_full)^T lambda + return dxdc.T @ w_full + + class ProximalProjection(ObjectiveFunction): """Remove equilibrium constraint by projecting onto constraint at each step. @@ -1040,7 +1098,9 @@ def compute_unscaled(self, x, constants=None): return self._objective.compute_unscaled(xopt, constants[0]) def grad(self, x, constants=None): - """Compute gradient of self.compute_scalar. + """Compute gradient of self.compute_scalar using an adjoint proximal solve. + + This avoids forming the full proximal tangent matrix. Parameters ---------- @@ -1052,30 +1112,51 @@ def grad(self, x, constants=None): Returns ------- g : ndarray - gradient vector. - + Gradient vector. """ - # We are looking for the gradient of L = 0.5 * G.T @ G - # Then, the gradient is ∇L = G.T @ J_of_G - # where J_of_G is the Jacobian of G with respect to the optimization variables - # We explained getting J_of_G in the _jvp method. It is basically, - # J_of_G = ∇G @ [dc_tangents - (∇F @ dx_tangents) ^ -1 @ (∇F @ dc_tangents)] - # where ∇G is the Jacobian of G with respect to full state vector - # and ∇F is the Jacobian of F with respect to full state vector. Then, - # ∇L = G.T @ ∇G @ [dc_tangents - (∇F @ dx_tangents) ^ -1 @ (∇F @ dc_tangents)] - # We get the part in [] using the _get_tangent method. - v = jnp.eye(x.shape[0]) constants = setdefault(constants, [None, None]) + + # Project current optimizer variables onto force balance. xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._get_tangent(u, xf, constants, op="scaled_error") - tangents = batched_vectorize( - jvpfun, - signature="(n)->(k)", - chunk_size=self._constraint._jac_chunk_size, - )(v) - g = self._objective.compute_scaled_error(xg, constants[0]) - g_vjp = self._objective.vjp_scaled_error(g, xg, constants[0]) - return tangents @ g_vjp + + # Outer residual G and outer pullback r = (dG/dx_full)^T @ G. + G = jnp.atleast_1d(self._objective.compute_scaled_error(xg, constants[0])) + r_full = self._objective.vjp_scaled_error(G, xg, constants[0]) + + # Split full-state cotangent by thing. + r_parts = jnp.split(r_full, np.cumsum(self._dimx_per_thing)) + r_eq_full = r_parts[self._eq_idx] + + # Direct term: dc_tangents^T @ r. + # For non-eq things c == x, so this is just the corresponding slice of r_full. + # For eq-facing optimizer variables c_eq, the direct map is dxdc. + grad_parts = [] + for i, ri in enumerate(r_parts): + if i == self._eq_idx: + grad_parts.append(self._dxdc.T @ ri) + else: + grad_parts.append(ri) + grad_direct = jnp.concatenate(grad_parts) + + # Proximal correction term: + # (dF/dc)^T (dF/dx @ dx_tangents)^(-T) dx_tangents^T r_eq_full + correction_eq_c = _proximal_adjoint_correction_pure( + self._constraint, + xf, + constants[1], + r_eq_full, + self._eq_solve_objective._feasible_tangents, + self._dxdc, + "scaled_error", + ) + + correction_parts = [ + jnp.zeros(dim, dtype=grad_direct.dtype) for dim in self._dimc_per_thing + ] + correction_parts[self._eq_idx] = correction_eq_c + grad_correction = jnp.concatenate(correction_parts) + + return grad_direct - grad_correction def hess(self, x, constants=None): """Compute Hessian of self.compute_scalar.