diff --git a/CHANGELOG.md b/CHANGELOG.md index 2449a4d868..94dde62afc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,12 @@ v0.17.0 ------- New Features +------------ +- Automatically-differentiable, non-singular Laplace BIE solver. +- Improved performance and accuracy of FFT interpolation in singular integrals + ([1](https://github.com/f0uriest/interpax/pull/116), [2](https://github.com/f0uriest/interpax/pull/117)). + This is useful for free surface optimization. +- [Plumbing for new magnetic field API](https://github.com/PlasmaControl/DESC/issues/1807). - Adds particle tracing capabilities in ``desc.particles`` module. - Particle tracing is done via ``desc.particles.trace_particles`` function. diff --git a/README.rst b/README.rst index 5bd77fe2ba..88e8c0774f 100644 --- a/README.rst +++ b/README.rst @@ -102,7 +102,7 @@ Contribute - `Contributing guidelines `_ - `Issue Tracker `_ - `Source Code `_ -- `Documentation `_ +- `Documentation `_ .. |License| image:: https://img.shields.io/github/license/PlasmaControl/desc?color=blue&logo=open-source-initiative&logoColor=white :target: https://github.com/PlasmaControl/DESC/blob/master/LICENSE diff --git a/desc/backend.py b/desc/backend.py index e893d540f5..5d14a61661 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -65,6 +65,50 @@ def print_backend_info(): ) +def _is_converged(residual, tol): + return jnp.sum(residual * residual) <= tol**2 + + +def _lstsq(A, y): + """Cholesky factorized least-squares. + + jnp.linalg.lstsq doesn't have JVP defined and is slower than needed, + so we use regularized cholesky. + + For square systems, solves Ax=y directly. + """ + A = jnp.atleast_2d(A) + y = jnp.atleast_1d(y) + eps = jnp.sqrt(jnp.finfo(A.dtype).eps) + if A.shape[-2] == A.shape[-1]: + return jnp.linalg.solve(A, y) if y.size > 1 else jnp.squeeze(y / A) + elif A.shape[-2] > A.shape[-1]: + P = A.T @ A + eps * jnp.eye(A.shape[-1]) + return cho_solve(cho_factor(P), A.T @ y) + else: + P = A @ A.T + eps * jnp.eye(A.shape[-2]) + return A.T @ cho_solve(cho_factor(P), y) + + +def _tangent_solve(g, y): + # System is always square. + return _lstsq(jax.jacfwd(g)(y), y) + + +def _tangent_solve_scalar(g, y): + return y / g(1.0) + + +def _map(f, xs, *, batch_size=None, in_axes=0, out_axes=0): + """Generalizes jax.lax.map; uses numpy.""" + if not isinstance(xs, np.ndarray): + raise NotImplementedError( + "Require numpy array input, or install jax to support pytrees." + ) + xs = np.moveaxis(xs, source=in_axes, destination=0) + return np.stack([f(x) for x in xs], axis=out_axes) + + def _diag_to_full(d, e): j = np.arange(d.shape[-1]) return ( @@ -377,7 +421,7 @@ def backtrack(xk1, fk1, d): def condfun(state): xk1, fk1, k1 = state - return (k1 < maxiter) & (jnp.dot(fk1, fk1) > tol**2) + return (k1 < maxiter) & (~_is_converged(fk1, tol)) def bodyfun(state): xk1, fk1, k1 = state @@ -393,41 +437,17 @@ def bodyfun(state): else: return state[0] - def tangent_solve(g, y): - return y / g(1.0) - if full_output: x, (res, niter) = jax.lax.custom_root( - res, x0, solve, tangent_solve, has_aux=True + res, x0, solve, _tangent_solve_scalar, has_aux=True ) return x, (abs(res), niter) else: - x = jax.lax.custom_root(res, x0, solve, tangent_solve, has_aux=False) + x = jax.lax.custom_root( + res, x0, solve, _tangent_solve_scalar, has_aux=False + ) return x - def _lstsq(A, y): - """Cholesky factorized least-squares. - - jnp.linalg.lstsq doesn't have JVP defined and is slower than needed, - so we use regularized cholesky. - - For square systems, solves Ax=y directly. - """ - A = jnp.atleast_2d(A) - y = jnp.atleast_1d(y) - eps = jnp.sqrt(jnp.finfo(A.dtype).eps) - if A.shape[-2] == A.shape[-1]: - return jnp.linalg.solve(A, y) if y.size > 1 else jnp.squeeze(y / A) - elif A.shape[-2] > A.shape[-1]: - P = A.T @ A + eps * jnp.eye(A.shape[-1]) - return cho_solve(cho_factor(P), A.T @ y) - else: - P = A @ A.T + eps * jnp.eye(A.shape[-2]) - return A.T @ cho_solve(cho_factor(P), y) - - def _tangent_solve(g, y): - return _lstsq(jax.jacfwd(g)(y), y) - def root( fun, x0, @@ -512,7 +532,7 @@ def backtrack(xk1, fk1, d): def condfun(state): xk1, fk1, k1 = state - return (k1 < maxiter) & (jnp.dot(fk1, fk1) > tol**2) + return (k1 < maxiter) & (~_is_converged(fk1, tol)) def bodyfun(state): xk1, fk1, k1 = state @@ -570,15 +590,6 @@ def bodyfun(state): excluded={"eigvals_only", "select", "select_range", "tol"}, ) - def _map(f, xs, *, batch_size=None, in_axes=0, out_axes=0): - """Generalizes jax.lax.map; uses numpy.""" - if not isinstance(xs, np.ndarray): - raise NotImplementedError( - "Require numpy array input, or install jax to support pytrees." - ) - xs = np.moveaxis(xs, source=in_axes, destination=0) - return np.stack([f(x) for x in xs], axis=out_axes) - def vmap(fun, in_axes=0, out_axes=0): """A numpy implementation of jax.lax.map whose API is a subset of jax.vmap. diff --git a/desc/basis.py b/desc/basis.py index d81162f6a3..8b09dc0c6a 100644 --- a/desc/basis.py +++ b/desc/basis.py @@ -421,7 +421,7 @@ def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): else: lidx = loutidx = np.arange(len(modes)) if (derivatives[1] != 0) or (derivatives[2] != 0): - return jnp.zeros((grid.num_nodes, modes.shape[0])) + return jnp.zeros((grid.num_nodes, self.num_modes)) if not len(modes): return np.array([]).reshape((grid.num_nodes, 0)) @@ -537,7 +537,7 @@ def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): else: nidx = noutidx = np.arange(len(modes)) if (derivatives[0] != 0) or (derivatives[1] != 0): - return jnp.zeros((grid.num_nodes, modes.shape[0])) + return jnp.zeros((grid.num_nodes, self.num_modes)) if not len(modes): return np.array([]).reshape((grid.num_nodes, 0)) @@ -609,6 +609,17 @@ def __init__(self, M, N, NFP=1, sym=False): self._spectral_indexing = "linear" self._modes = self._get_modes(M=self.M, N=self.N) super().__init__() + self._gauge_idx = np.asarray(self.get_idx(error=False)).squeeze() + + def _set_up(self): + """Do things after loading or changing resolution.""" + super()._set_up() + self._gauge_idx = np.asarray(self.get_idx(error=False)).squeeze() + + @property + def gauge_idx(self): + """ndarray: Index of constant-potential gauge mode, if present.""" + return self._gauge_idx def _get_modes(self, M, N): """Get mode numbers for double Fourier series. @@ -635,7 +646,13 @@ def _get_modes(self, M, N): z = np.zeros_like(m) return np.array([z, m, n]).T - def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): + def evaluate( + self, + grid, + derivatives=np.array([0, 0, 0]), + modes=None, + **kwargs, + ): """Evaluate basis functions at specified nodes. Parameters @@ -644,7 +661,7 @@ def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): Node coordinates, in (rho,theta,zeta). derivatives : ndarray of int, shape(num_derivatives,3) Order of derivatives to compute in (rho,theta,zeta). - modes : ndarray of in, shape(num_modes,3), optional + modes : ndarray of int, shape(num_modes,3), optional Basis modes to evaluate (if None, full basis is used). Returns @@ -670,7 +687,7 @@ def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): midx = moutidx = np.arange(len(modes)) nidx = noutidx = np.arange(len(modes)) if derivatives[0] != 0: - return jnp.zeros((grid.num_nodes, modes.shape[0])) + return jnp.zeros((grid.num_nodes, self.num_modes)) if not len(modes): return np.array([]).reshape((grid.num_nodes, 0)) @@ -688,17 +705,18 @@ def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): _, t, z = grid.nodes.T _, m, n = modes.T - t = t[tidx] - z = z[zidx] + t = kwargs["t"] if "t" in kwargs else t[tidx] + z = kwargs["z"] if "z" in kwargs else z[zidx] m = m[midx] n = n[nidx] - poloidal = fourier(t[:, np.newaxis], m, 1, derivatives[1]) - toroidal = fourier(z[:, np.newaxis], n, self.NFP, derivatives[2]) - poloidal = poloidal[toutidx][:, moutidx] - toroidal = toroidal[zoutidx][:, noutidx] - - return poloidal * toroidal + poloidal = fourier(t[:, np.newaxis], m, 1, derivatives[1])[:, moutidx] + toroidal = fourier(z[:, np.newaxis], n, self.NFP, derivatives[2])[:, noutidx] + if grid.is_meshgrid and grid.num_rho == 1: + return (poloidal[:, np.newaxis] * toroidal[np.newaxis]).reshape( + -1, self.num_modes, order="F" + ) + return poloidal[toutidx] * toroidal[zoutidx] def change_resolution(self, M, N, NFP=None, sym=None): """Change resolution of the basis to the given resolutions. @@ -881,7 +899,7 @@ def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): lmidx = lmoutidx = np.arange(len(modes)) midx = moutidx = np.arange(len(modes)) if derivatives[2] != 0: - return jnp.zeros((grid.num_nodes, modes.shape[0])) + return jnp.zeros((grid.num_nodes, self.num_modes)) if not len(modes): return np.array([]).reshape((grid.num_nodes, 0)) @@ -1437,7 +1455,7 @@ def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): else: lidx = loutidx = np.arange(len(modes)) if (derivatives[1] != 0) or (derivatives[2] != 0): - return jnp.zeros((grid.num_nodes, modes.shape[0])) + return jnp.zeros((grid.num_nodes, self.num_modes)) if not len(modes): return np.array([]).reshape((grid.num_nodes, 0)) diff --git a/desc/coils.py b/desc/coils.py index 42f96b1ce8..104f6f0740 100644 --- a/desc/coils.py +++ b/desc/coils.py @@ -1603,7 +1603,7 @@ def compute( Returns ------- - data : list of dict of ndarray + data : list[dict[str, jnp.ndarray]] Computed quantity and intermediate variables, for each coil in the set. List entries map to coils in coilset, each dict contains data for an individual coil. @@ -1843,6 +1843,7 @@ def body(AB, x): return AB, None AB += scan(body, jnp.zeros(coords_nfp.shape), tree_stack(params))[0] + return AB AB = fori_loop(0, self.NFP, nfp_loop, jnp.zeros_like(coords_rpz)) @@ -2753,7 +2754,7 @@ def compute( Returns ------- - data : list of dict of ndarray + data : list[dict[str, jnp.ndarray]] Computed quantity and intermediate variables, for each coil in the set. List entries map to coils in coilset, each dict contains data for an individual coil. diff --git a/desc/compute/__init__.py b/desc/compute/__init__.py index 343ca29002..5f6e169e84 100644 --- a/desc/compute/__init__.py +++ b/desc/compute/__init__.py @@ -35,6 +35,7 @@ _fast_ion, _field, _geometry, + _laplace, _metric, _neoclassical, _old, diff --git a/desc/compute/_basis_vectors.py b/desc/compute/_basis_vectors.py index a273b5c03a..4293518997 100644 --- a/desc/compute/_basis_vectors.py +++ b/desc/compute/_basis_vectors.py @@ -29,7 +29,7 @@ data=["B", "|B|"], ) def _b(params, transforms, profiles, data, **kwargs): - data["b"] = (data["B"].T / data["|B|"]).T + data["b"] = data["B"] / data["|B|"][:, jnp.newaxis] return data @@ -54,6 +54,29 @@ def _e_sup_rho(params, transforms, profiles, data, **kwargs): return data +@register_compute_fun( + name="e_theta x e_zeta", + label="\\mathbf{e}_{\\theta} \\times \\mathbf{e}_{\\zeta}", + units="m^{2}", + units_long="square meters", + description="ρ surface area vector", + dim=3, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["e_theta", "e_zeta"], + parameterization=[ + "desc.equilibrium.equilibrium.Equilibrium", + "desc.geometry.surface.FourierRZToroidalSurface", + ], + aliases=["e^rho*sqrt(g)"], +) +def _e_theta_x_e_zeta(params, transforms, profiles, data, **kwargs): + data["e_theta x e_zeta"] = cross(data["e_theta"], data["e_zeta"]) + return data + + @register_compute_fun( name="e^rho_r", label="\\partial_{\\rho} \\mathbf{e}^{\\rho}", @@ -541,10 +564,11 @@ def _e_sup_rho_zz(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["e^theta*sqrt(g)", "sqrt(g)"], + data=["e_zeta x e_rho", "sqrt(g)"], + aliases=["grad(theta)"], ) def _e_sup_theta(params, transforms, profiles, data, **kwargs): - data["e^theta"] = (data["e^theta*sqrt(g)"].T / data["sqrt(g)"]).T + data["e^theta"] = data["e_zeta x e_rho"] / data["sqrt(g)"][:, jnp.newaxis] return data @@ -678,11 +702,11 @@ def _e_sup_vartheta_p_PEST(params, transforms, profiles, data, **kwargs): @register_compute_fun( - name="e^theta*sqrt(g)", - label="\\mathbf{e}^{\\theta} \\sqrt{g}", + name="e_zeta x e_rho", + label="\\mathbf{e}_{\\zeta} \\times \\mathbf{e}_{\\rho}", units="m^{2}", units_long="square meters", - description="Contravariant poloidal basis vector weighted by 3-D volume Jacobian", + description="ΞΈ surface area vector", dim=3, params=[], transforms={}, @@ -692,11 +716,12 @@ def _e_sup_vartheta_p_PEST(params, transforms, profiles, data, **kwargs): parameterization=[ "desc.equilibrium.equilibrium.Equilibrium", ], + aliases=["e^theta*sqrt(g)"], ) -def _e_sup_theta_times_sqrt_g(params, transforms, profiles, data, **kwargs): +def _e_zeta_x_e_rho(params, transforms, profiles, data, **kwargs): # At the magnetic axis, this function returns the multivalued map whose # image is the set { 𝐞^ΞΈ √g | ρ=0 }. - data["e^theta*sqrt(g)"] = cross(data["e_zeta"], data["e_rho"]) + data["e_zeta x e_rho"] = cross(data["e_zeta"], data["e_rho"]) return data @@ -1110,6 +1135,7 @@ def _e_sup_theta_zz(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["e_rho", "e_theta/sqrt(g)"], + aliases=["grad(zeta)"], ) def _e_sup_zeta(params, transforms, profiles, data, **kwargs): # At the magnetic axis, this function returns the multivalued map whose @@ -3552,22 +3578,21 @@ def _gradpsi(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["e_theta", "e_zeta", "|e_theta x e_zeta|"], - axis_limit_data=["e_theta_r", "|e_theta x e_zeta|_r"], - parameterization=[ - "desc.equilibrium.equilibrium.Equilibrium", - ], + data=["e_theta x e_zeta", "|e_theta x e_zeta|"], + axis_limit_data=["e_theta_r", "e_zeta", "|e_theta x e_zeta|_r"], + parameterization=["desc.equilibrium.equilibrium.Equilibrium"], ) def _n_rho(params, transforms, profiles, data, **kwargs): # Equal to 𝐞^ρ / β€–πž^ρ‖ but works correctly for surfaces as well that don't # have contravariant basis defined. data["n_rho"] = transforms["grid"].replace_at_axis( - safediv(cross(data["e_theta"], data["e_zeta"]).T, data["|e_theta x e_zeta|"]).T, + safediv(data["e_theta x e_zeta"], data["|e_theta x e_zeta|"][:, jnp.newaxis]), # At the magnetic axis, this function returns the multivalued map whose # image is the set { 𝐞^ρ / β€–πž^ρ‖ | ρ=0 }. lambda: safediv( - cross(data["e_theta_r"], data["e_zeta"]).T, data["|e_theta x e_zeta|_r"] - ).T, + cross(data["e_theta_r"], data["e_zeta"]), + data["|e_theta x e_zeta|_r"][:, jnp.newaxis], + ), ) return data @@ -3583,18 +3608,15 @@ def _n_rho(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["e_theta", "e_zeta", "|e_theta x e_zeta|"], - parameterization=[ - "desc.geometry.surface.FourierRZToroidalSurface", - ], + data=["e^rho*sqrt(g)", "|e_theta x e_zeta|"], + parameterization=["desc.geometry.surface.FourierRZToroidalSurface"], ) def _n_rho_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs): # Equal to 𝐞^ρ / β€–πž^ρ‖ but works correctly for surfaces as well that don't # have contravariant basis defined. data["n_rho"] = safediv( - cross(data["e_theta"], data["e_zeta"]).T, data["|e_theta x e_zeta|"] - ).T - + data["e^rho*sqrt(g)"], data["|e_theta x e_zeta|"][:, jnp.newaxis] + ) return data @@ -3647,7 +3669,7 @@ def _n_rho_z(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["e_rho", "e_zeta", "|e_zeta x e_rho|"], + data=["e_zeta x e_rho", "|e_zeta x e_rho|"], parameterization=[ "desc.equilibrium.equilibrium.Equilibrium", ], @@ -3655,9 +3677,7 @@ def _n_rho_z(params, transforms, profiles, data, **kwargs): def _n_theta(params, transforms, profiles, data, **kwargs): # Equal to 𝐞^ΞΈ / β€–πž^ΞΈβ€– but works correctly for surfaces as well that don't # have contravariant basis defined. - data["n_theta"] = ( - cross(data["e_zeta"], data["e_rho"]).T / data["|e_zeta x e_rho|"] - ).T + data["n_theta"] = data["e_zeta x e_rho"] / data["|e_zeta x e_rho|"][:, jnp.newaxis] return data @@ -3683,12 +3703,16 @@ def _n_zeta(params, transforms, profiles, data, **kwargs): # Equal to 𝐞^ΞΆ / β€–πž^ΞΆβ€– but works correctly for surfaces as well that don't # have contravariant basis defined. data["n_zeta"] = transforms["grid"].replace_at_axis( - safediv(cross(data["e_rho"], data["e_theta"]).T, data["|e_rho x e_theta|"]).T, + safediv( + cross(data["e_rho"], data["e_theta"]), + data["|e_rho x e_theta|"][:, jnp.newaxis], + ), # At the magnetic axis, this function returns the multivalued map whose # image is the set { 𝐞^ΞΆ / β€–πž^ΞΆβ€– | ρ=0 }. lambda: safediv( - cross(data["e_rho"], data["e_theta_r"]).T, data["|e_rho x e_theta|_r"] - ).T, + cross(data["e_rho"], data["e_theta_r"]), + data["|e_rho x e_theta|_r"][:, jnp.newaxis], + ), ) return data @@ -3989,6 +4013,79 @@ def _e_alpha_rp_norm(params, transforms, profiles, data, **kwargs): return data +@register_compute_fun( + name="n_rho x grad(theta)", + label="\\Vert \\mathbf{e}^{\\rho} \\Vert^{-1} \\mathbf{e}^{\\rho} " + "\times \\mathbf{e}^{\\theta}", + units="m^{-1}", + units_long="inverse meters", + description="Rotated surface gradient of poloidal angle.", + dim=3, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["e_zeta", "|e_theta x e_zeta|"], + parameterization=[ + "desc.equilibrium.equilibrium.Equilibrium", + "desc.geometry.surface.FourierRZToroidalSurface", + ], +) +def _surface_gradient_theta(params, transforms, profiles, data, **kwargs): + data["n_rho x grad(theta)"] = ( + data["e_zeta"] / data["|e_theta x e_zeta|"][:, jnp.newaxis] + ) + return data + + +@register_compute_fun( + name="n_rho x grad(zeta)", + label="\\Vert \\mathbf{e}^{\\rho} \\Vert^{-1} \\mathbf{e}^{\\rho} " + "\times \\mathbf{e}^{\\zeta}", + units="m^{-1}", + units_long="inverse meters", + description="Rotated surface gradient of toroidal angle.", + dim=3, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["e_theta", "|e_theta x e_zeta|"], + axis_limit_data=["e_theta_r", "|e_theta x e_zeta|_r"], + parameterization=["desc.equilibrium.equilibrium.Equilibrium"], +) +def _surface_gradient_zeta(params, transforms, profiles, data, **kwargs): + data["n_rho x grad(zeta)"] = transforms["grid"].replace_at_axis( + safediv(-data["e_theta"], data["|e_theta x e_zeta|"][:, jnp.newaxis]), + lambda: -data["e_theta_r"] / data["|e_theta x e_zeta|_r"][:, jnp.newaxis], + ) + return data + + +@register_compute_fun( + name="n_rho x grad(zeta)", + label="\\Vert \\mathbf{e}^{\\rho} \\Vert^{-1} \\mathbf{e}^{\\rho} " + "\times \\mathbf{e}^{\\zeta}", + units="m^{-1}", + units_long="inverse meters", + description="Rotated surface gradient of toroidal angle.", + dim=3, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["e_theta", "|e_theta x e_zeta|"], + parameterization=["desc.geometry.surface.FourierRZToroidalSurface"], +) +def _surface_gradient_zeta_FourierRZToroidalSurface( + params, transforms, profiles, data, **kwargs +): + data["n_rho x grad(zeta)"] = safediv( + -data["e_theta"], data["|e_theta x e_zeta|"][:, jnp.newaxis] + ) + return data + + ################################################################################## ##########---------------HIGHER-ORDER DERIVATIVES (PEST)---------------########### ################################################################################## diff --git a/desc/compute/_core.py b/desc/compute/_core.py index 04c60e00da..e0de60b0f6 100644 --- a/desc/compute/_core.py +++ b/desc/compute/_core.py @@ -3180,7 +3180,7 @@ def _phi_zzz(params, transforms, profiles, data, **kwargs): ], ) def _rho(params, transforms, profiles, data, **kwargs): - data["rho"] = transforms["grid"].nodes[:, 0] + data["rho"] = jnp.asarray(transforms["grid"].nodes[:, 0]) return data @@ -3202,7 +3202,7 @@ def _rho(params, transforms, profiles, data, **kwargs): ], ) def _theta(params, transforms, profiles, data, **kwargs): - data["theta"] = transforms["grid"].nodes[:, 1] + data["theta"] = jnp.asarray(transforms["grid"].nodes[:, 1]) return data @@ -3528,5 +3528,5 @@ def _theta_PEST_ttz(params, transforms, profiles, data, **kwargs): ], ) def _zeta(params, transforms, profiles, data, **kwargs): - data["zeta"] = transforms["grid"].nodes[:, 2] + data["zeta"] = jnp.asarray(transforms["grid"].nodes[:, 2]) return data diff --git a/desc/compute/_laplace.py b/desc/compute/_laplace.py new file mode 100644 index 0000000000..b56224433f --- /dev/null +++ b/desc/compute/_laplace.py @@ -0,0 +1,1137 @@ +"""Compute functions for multiply connected Laplace solver as described in [1]_. + +References +---------- +.. [1] Unalmis et al. New high-order accurate free surface stellarator + equilibria optimization and boundary integral methods in DESC. + +""" + +from functools import partial +from typing import NamedTuple, Optional + +import equinox as eqx +import jax +import lineax as lx +import optimistix as optx +from interpax_fft import rfft_interp2d + +from desc.backend import jnp +from desc.integrals.singularities import ( + _kernel_BS_plus_grad_S, + _kernel_dipole, + _kernel_dipole_plus_half, + _kernel_monopole, + _nonsingular_part, + _prune_data, + get_interpolator, + singular_integral, +) +from desc.utils import cross, dot, errorif + +from .data_index import register_compute_fun + + +class Options(NamedTuple): + """Laplace solver options.""" + + Phi_0: Optional[jax.Array] = None + """Initial guess for iteration.""" + + atol: float = 1e-7 + """Absolute error tolerance for the iterative linear solve. Default is ``1e-7``.""" + + rtol: float = 1e-6 + """Relative error tolerance for the iterative linear solve. Default is ``1e-6``.""" + + max_steps: int = 10 + """Maximum number of steps for iterative linear solve. + + Typically converges in 2 iterations. Default max value is ``10``. + """ + + problem: str = "interior Neumann" + """Boundary value problem to solve. + + One of ``"interior Neumann"``, ``"exterior Neumann"``, or ``"interior Dirichlet"``. + (In some routines this may be determined automatically.) + """ + + solve_method: str = "auto" + """Method to use for the scalar potential solve. + + One of ``"auto"``, ``"gmres"``, or ``"direct"``. If ``"auto"``, then uses + GMRES when the problem supports it, otherwise uses the direct solve. Default + is ``"auto"``. If GMRES errors due to incompatibility with old JAX versions, + ``"fixed_point"`` can be selected instead. + """ + + full_output: bool = False + """Whether to return diagnostic output of the iterative potential solve. + + If ``True``, computes the maximum error ``Phi error`` and stores the number + of steps ``num_steps`` used by the scalar potential solver. Default is + ``False``. + """ + + chunk_size: Optional[int] = None + """Size to split integral computation into chunks. + + If no chunking should be done or the chunk size is the full input then + supply ``None``. Default is ``None``. Recommend to verify computation with + ``chunk_size`` set to a small number due to bugs in JAX or XLA. + """ + + B_coil_chunk_size: Optional[int] = None + """Size to split coil integral computation into chunks. + + If no chunking should be done or the chunk size is the full input then + supply ``None``. Default is ``None``. + """ + + D_quad: bool = False + """Developer option for double-layer potential quadrature. + + Set to ``True`` to perform double-layer potential quadrature without removing + singularities. Default is ``False``. + """ + + @staticmethod + def select_solver(options): + """Pick the solver based on the problem.""" + solve_method = options.solve_method + is_interior_neumann = options.problem == "interior Neumann" + if solve_method == "auto": + solve_method = "direct" if is_interior_neumann else "gmres" + errorif( + solve_method not in {"fixed_point", "gmres", "direct"}, + msg="solve_method must be one of 'auto', 'fixed_point', 'gmres', " + f"or 'direct', got {solve_method!r}.", + ) + errorif( + solve_method in {"fixed_point", "gmres"} and is_interior_neumann, + msg=f"solve_method={solve_method!r} is not supported for interior Neumann " + "problems. Use solve_method='direct' instead.", + ) + return options._replace(solve_method=solve_method) + + +def _D_plus_half( + eval_data, + source_data, + interpolator, + basis=None, + chunk_size=None, + prune_data=True, + _D_quad=False, +): + """Compute (D[Ξ¦] + Ξ¦/2)(x). + + D[Ξ¦](x) = ∫_y Ξ¦(y)γ€ˆβˆ‡_x G(xβˆ’y),ds(y)〉. + + Parameters + ---------- + basis : DoubleFourierSeries + If not supplied, then computes (D[Ξ¦] + Ξ¦/2)(x). + If supplied, then constructs the operator which + acts on the spectral coefficients of Ξ¦ in the supplied + secular basis. + prune_data : bool + Whether the data should be pruned. Default is True. + _D_quad : bool + Set to ``True`` to perform double layer potential quadrature without removing + singularities. Default is ``False``. This is intended for developer use. + + """ + if basis is None: + ndim = 1 + known_map = None + else: + ndim = basis.num_modes + known_map = ("Phi (periodic)", basis.evaluate) + + kernel = _kernel_dipole if _D_quad else _kernel_dipole_plus_half + + result = singular_integral( + eval_data, + source_data, + interpolator, + kernel, + known_map=known_map, + ndim=ndim, + chunk_size=chunk_size, + _prune_data=prune_data, + ) + if ndim == 1: + result = result.squeeze(-1) + + if _D_quad: + result += eval_data["Phi(x) (periodic)"] / 2 + + return result + + +@eqx.filter_jit +def _direct_solve( + boundary_condition, potential_data, source_data, interpolator, basis, options +): + potential_grid = interpolator.eval_grid + source_grid = interpolator.source_grid + + assert basis.M <= potential_grid.M + assert basis.N <= potential_grid.N + well_posed = potential_grid.num_nodes == basis.num_modes + if not well_posed: + well_posed = None + + potential_data, source_data = _prune_data( + potential_data, + potential_grid, + source_data, + source_grid, + _kernel_dipole_plus_half, + ) + Phi = basis.evaluate(potential_grid) + potential_data["Phi(x) (periodic)"] = Phi + source_data["Phi (periodic)"] = ( + Phi + if ( + potential_grid.num_theta == source_grid.num_theta + and potential_grid.num_zeta == source_grid.num_zeta + ) + else basis.evaluate(source_grid) + ) + + D = _D_plus_half( + potential_data, + source_data, + interpolator, + basis, + options.chunk_size, + prune_data=False, + _D_quad=options.D_quad, + ) + assert D.shape == (potential_grid.num_nodes, basis.num_modes) + + insert_gauge = False + if options.problem in ("exterior Neumann", "interior Dirichlet"): + # This system is negative definite, but perhaps not symmetric. + # Lineax assumes negative semidefinite means the operator is symmetric. + # Hence we do not set that tag. + D -= Phi + elif options.problem == "interior Neumann" and basis.gauge_idx.size: + # This system is positive definite, but the same logic above applies. + if well_posed: + D = D.at[-1].set(0.0).at[-1, basis.gauge_idx].set(1.0) + boundary_condition = boundary_condition.at[-1].set(0.0) + else: + D = jnp.delete(D, basis.gauge_idx, axis=1, assume_unique_indices=True) + insert_gauge = True + + D = lx.MatrixLinearOperator(D) + Phi_mn = lx.linear_solve( + D, boundary_condition, solver=lx.AutoLinearSolver(well_posed=well_posed) + ).value + if insert_gauge: + Phi_mn = jnp.insert(Phi_mn, basis.gauge_idx, 0.0) + + return Phi_mn + + +@eqx.filter_jit +def _iterative_solve( + boundary_condition, potential_data, source_data, interpolator, options +): + potential_grid = interpolator.eval_grid + source_grid = interpolator.source_grid + + potential_data, source_data = _prune_data( + potential_data, + potential_grid, + source_data, + source_grid, + _kernel_dipole_plus_half, + ) + Phi_0 = options.Phi_0 + if Phi_0 is None: + Phi_0 = jnp.ones(potential_grid.num_nodes) + assert Phi_0.size == potential_grid.num_nodes + + if options.solve_method == "gmres": + operator = lx.FunctionLinearOperator( + partial( + _linear_potential_operator, + potential_data=potential_data, + source_data=source_data, + interpolator=interpolator, + chunk_size=options.chunk_size, + ), + jax.ShapeDtypeStruct(Phi_0.shape, Phi_0.dtype), + ) + solution = lx.linear_solve( + operator, + boundary_condition, + solver=lx.GMRES( + rtol=options.rtol, + atol=options.atol, + max_steps=options.max_steps, + ), + options={"y0": Phi_0}, + throw=False, + ) + if options.full_output: + err = jnp.abs(operator.mv(solution.value) - boundary_condition).max() + return solution.value, (err, solution.stats["num_steps"]) + return solution.value + + # Some JAX versions fail to transpose scan, so we keep fixed point. + xi = 2 / 3 + args = ( + boundary_condition, + potential_data, + source_data, + interpolator, + options.chunk_size, + xi, + ) + solution = optx.fixed_point( + _iteration_operator, + optx.FixedPointIteration(rtol=options.rtol, atol=options.atol), + Phi_0, + args, + max_steps=options.max_steps, + adjoint=optx.ImplicitAdjoint( + lx.GMRES( + rtol=options.rtol, + atol=options.atol, + max_steps=options.max_steps, + ) + ), + throw=False, + ) + if options.full_output: + err = jnp.abs(_iteration_operator(solution.value, args) - solution.value).max() + return solution.value, (err, solution.stats["num_steps"]) + return solution.value + + +def _iteration_operator(Phi, args): + """Equation 3.12 in [1]_.""" + gamma, potential_data, source_data, interpolator, chunk_size, xi = args + potential_data["Phi(x) (periodic)"] = Phi + source_data["Phi (periodic)"] = _interp( + Phi, interpolator.eval_grid, interpolator.source_grid + ) + return ( + _D_plus_half( + potential_data, + source_data, + interpolator, + chunk_size=chunk_size, + prune_data=False, + ) + + (xi - 1) * Phi + - gamma + ) / xi + + +def _linear_potential_operator( + Phi, potential_data, source_data, interpolator, chunk_size +): + """Equation solved by the iterative linear solver.""" + potential_data["Phi(x) (periodic)"] = Phi + source_data["Phi (periodic)"] = _interp( + Phi, interpolator.eval_grid, interpolator.source_grid + ) + return ( + _D_plus_half( + potential_data, + source_data, + interpolator, + chunk_size=chunk_size, + prune_data=False, + ) + - Phi + ) + + +def _interp(x, input_grid, output_grid): + if ( + input_grid.num_theta == output_grid.num_theta + and input_grid.num_zeta == output_grid.num_zeta + ): + return x + return rfft_interp2d( + input_grid.meshgrid_reshape(x, "rtz")[0], + output_grid.num_theta, + output_grid.num_zeta, + dx=2 * jnp.pi / input_grid.num_theta, + dy=2 * jnp.pi / input_grid.num_zeta / input_grid.NFP, + ).ravel(order="F") + + +@register_compute_fun( + name="interpolator", + label="", + units="", + units_long="", + description="Interpolator for singular integrals.", + dim=1, + coordinates="tz", + params=[], + transforms={"grid": []}, + profiles=[], + data=["|e_theta x e_zeta|", "e_theta", "e_zeta"], + parameterization=["desc.geometry.surface.FourierRZToroidalSurface"], + q="int : Order of quadrature in polar domain.", + potential_grid="""LinearGrid : + Grid to evaluate potential on boundary. + If not given, default is to interpolate to source grid. + """, + warn_fft="""bool : + Whether to warn if the interpolation will be lossy. Default is ``True``. + """, +) +def _interpolator(params, transforms, profiles, data, **kwargs): + # noqa: unused dependency + grid = transforms["grid"] + potential_grid = kwargs.get("potential_grid", grid) + data["interpolator"] = get_interpolator(potential_grid, grid, data, **kwargs) + + # TODO: interpolate Rb_mn, Zb_mn, and omegab_mn directly + data["potential data"] = { + "R": _interp(data["R"], grid, potential_grid), + "omega": _interp(data["omega"], grid, potential_grid), + "Z": _interp(data["Z"], grid, potential_grid), + } + zeta = potential_grid.nodes[:, 2] + data["potential data"]["phi"] = zeta + data["potential data"]["omega"] + + return data + + +@register_compute_fun( + name="potential data", + label="potential data", + units="~", + units_long="not applicable", + description="RpZ position on the potential grid", + dim=1, + coordinates="rtz", + params=[], + transforms={}, + profiles=[], + data=["interpolator"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", + public=False, +) +def _potential_grid_position(params, transforms, profiles, data, **kwargs): + # noqa: unused dependency + return data + + +@register_compute_fun( + name="S[B0*n]", + label="S[B_0 \\cdot n_{\\rho}]", + units="T m", + units_long="Tesla meter", + description="Single layer potential of monopole density B0*n", + dim=1, + coordinates="tz", + params=[], + transforms={}, + profiles=[], + data=_kernel_monopole.keys + ["interpolator"], + resolution_requirement="tz", + grid_requirement={"can_fft2": True}, + parameterization="desc.magnetic_fields._laplace.SourceFreeField", + options=Options.__doc__, + public=False, +) +def _S_B0_n(params, transforms, profiles, data, **kwargs): + # noqa: unused dependency + options = kwargs.get("options", Options()) + data["S[B0*n]"] = singular_integral( + data.get("potential data", data), + data, + data["interpolator"], + _kernel_monopole, + chunk_size=options.chunk_size, + ).squeeze(-1) + return data + + +@register_compute_fun( + name="Phi_mn", + label="\\Phi_{m n}", + units="T m", + units_long="Tesla meter", + description="Fourier coefficients of periodic part of potential", + dim=1, + coordinates="tz", + params=[], + transforms={"Phi": [[0, 0, 0]]}, + profiles=[], + data=list(set(_kernel_dipole_plus_half.keys) - {"Phi (periodic)"}) + + ["S[B0*n]", "interpolator"], + resolution_requirement="tz", + grid_requirement={"can_fft2": True}, + parameterization="desc.magnetic_fields._laplace.SourceFreeField", + options=Options.__doc__, +) +def _scalar_potential_mn_Neumann(params, transforms, profiles, data, **kwargs): + # noqa: unused dependency + options = Options.select_solver(kwargs.get("options", Options())) + + if options.solve_method == "direct": + data["Phi_mn"] = _direct_solve( + data["S[B0*n]"], + data.get("potential data", data), + data, + data["interpolator"], + transforms["Phi"].basis, + options, + ) + else: + data["Phi (periodic)"] = _iterative_solve( + data["S[B0*n]"], + data.get("potential data", data), + data, + data["interpolator"], + options, + ) + if options.full_output: + data["Phi (periodic)"], (data["Phi error"], data["num_steps"]) = data[ + "Phi (periodic)" + ] + + assert data["Phi (periodic)"].size == transforms["Phi"].grid.num_nodes + data["Phi_mn"] = transforms["Phi"].fit(data["Phi (periodic)"]) + return data + + +@register_compute_fun( + name="Phi (periodic)", + label="\\Phi", + units="T m", + units_long="Tesla meter", + description="Periodic part of magnetic scalar potential", + dim=1, + coordinates="tz", + params=[], + transforms={"Phi": [[0, 0, 0]]}, + profiles=[], + data=["Phi_mn"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", +) +def _Phi_periodic_potential(params, transforms, profiles, data, **kwargs): + assert data["Phi_mn"].size == transforms["Phi"].basis.num_modes + data["Phi (periodic)"] = transforms["Phi"].transform(data["Phi_mn"]) + return data + + +@register_compute_fun( + name="Phi_t (periodic)", + label="\\partial_{\\theta} \\Phi_{\\text{periodic}}", + units="T m", + units_long="Tesla meter", + description="Magnetic scalar potential, poloidal derivative", + dim=1, + coordinates="tz", + params=[], + transforms={"Phi": [[0, 1, 0]]}, + profiles=[], + data=["Phi_mn"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", +) +def _pot_Phi_t_periodic(params, transforms, profiles, data, **kwargs): + assert data["Phi_mn"].size == transforms["Phi"].basis.num_modes + data["Phi_t (periodic)"] = transforms["Phi"].transform(data["Phi_mn"], dt=1) + return data + + +@register_compute_fun( + name="Phi_z (periodic)", + label="\\partial_{\\zeta} \\Phi_{\\text{periodic}}", + units="T m", + units_long="Tesla meter", + description="Magnetic scalar potential, toroidal derivative", + dim=1, + coordinates="tz", + params=[], + transforms={"Phi": [[0, 0, 1]]}, + profiles=[], + data=["Phi_mn"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", +) +def _pot_Phi_z_periodic(params, transforms, profiles, data, **kwargs): + assert data["Phi_mn"].size == transforms["Phi"].basis.num_modes + data["Phi_z (periodic)"] = transforms["Phi"].transform(data["Phi_mn"], dz=1) + return data + + +@register_compute_fun( + name="K_vc (periodic)", + label="-n \\times \\nabla \\Phi_{\\text{periodic}}", + units="T", + units_long="Tesla", + description="Virtual surface current due to potential", + dim=3, + coordinates="tz", + params=[], + transforms={}, + profiles=[], + data=[ + "n_rho x grad(theta)", + "n_rho x grad(zeta)", + "Phi_t (periodic)", + "Phi_z (periodic)", + ], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", +) +def _virtual_surface_current_periodic(params, transforms, profiles, data, **kwargs): + data["K_vc (periodic)"] = -( + data["Phi_t (periodic)"][:, None] * data["n_rho x grad(theta)"] + + data["Phi_z (periodic)"][:, None] * data["n_rho x grad(zeta)"] + ) + return data + + +@register_compute_fun( + name="Phi", + label="\\Phi", + units="T m", + units_long="Tesla meter", + description="Magnetic scalar potential", + dim=1, + coordinates="tz", + params=["I", "Y"], + transforms={}, + profiles=[], + data=["Phi (periodic)", "theta", "zeta"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", +) +def _Phi_scalar_potential(params, transforms, profiles, data, **kwargs): + data["Phi"] = ( + data["Phi (periodic)"] + + params["I"] * data["theta"] + + params["Y"] * data["zeta"] + ) + return data + + +@register_compute_fun( + name="Phi error", + label="\\Phi_{\\text{error}}", + units="T m", + units_long="Tesla meter", + description="Magnetic scalar potential error", + dim=0, + coordinates="", + params=[], + transforms={}, + profiles=[], + data=["Phi_mn"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", + public=False, +) +def _Phi_error(params, transforms, profiles, data, **kwargs): + # noqa: unused dependency + return data + + +@register_compute_fun( + name="num_steps", + label="\\text{number of steps}", + units="", + units_long="", + description="Magnetic scalar potential number of steps for inversion", + dim=0, + coordinates="", + params=[], + transforms={}, + profiles=[], + data=["Phi_mn"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", + public=False, +) +def _Phi_num_steps(params, transforms, profiles, data, **kwargs): + # noqa: unused dependency + return data + + +@register_compute_fun( + name="Phi_t", + label="\\partial_{\\theta} \\Phi", + units="T m", + units_long="Tesla meter", + description="Magnetic scalar potential, poloidal derivative", + dim=1, + coordinates="tz", + params=["I"], + transforms={}, + profiles=[], + data=["Phi_t (periodic)"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", +) +def _pot_Phi_t(params, transforms, profiles, data, **kwargs): + data["Phi_t"] = data["Phi_t (periodic)"] + params["I"] + return data + + +@register_compute_fun( + name="Phi_z", + label="\\partial_{\\zeta} \\Phi", + units="T m", + units_long="Tesla meter", + description="Magnetic scalar potential, toroidal derivative", + dim=1, + coordinates="tz", + params=["Y"], + transforms={}, + profiles=[], + data=["Phi_z (periodic)"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", +) +def _pot_Phi_z(params, transforms, profiles, data, **kwargs): + data["Phi_z"] = data["Phi_z (periodic)"] + params["Y"] + return data + + +@register_compute_fun( + name="K_vc", + label="-n \\times \\nabla \\Phi", + units="T", + units_long="Tesla", + description="Virtual surface current due to potential", + dim=3, + coordinates="tz", + params=[], + transforms={}, + profiles=[], + data=["n_rho x grad(theta)", "n_rho x grad(zeta)", "Phi_t", "Phi_z"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", +) +def _virtual_surface_current(params, transforms, profiles, data, **kwargs): + data["K_vc"] = -( + data["Phi_t"][:, None] * data["n_rho x grad(theta)"] + + data["Phi_z"][:, None] * data["n_rho x grad(zeta)"] + ) + return data + + +@register_compute_fun( + name="|K_vc|^2", + label="\\vert K_{\\text{vc}}) \\vert^2", + units="T^2", + units_long="Tesla squared", + description="Squared norm of virtual surface current", + dim=1, + coordinates="tz", + params=[], + transforms={}, + profiles=[], + data=["K_vc"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", +) +def _K_vc_squared(params, transforms, profiles, data, **kwargs): + data["|K_vc|^2"] = dot(data["K_vc"], data["K_vc"]) + return data + + +@register_compute_fun( + name="βˆ‡Ο†", + label="\\nabla \\varphi", + units="T", + units_long="Tesla", + description="Magnetic field due to potential which solves the" + " boundary value problem.", + dim=3, + coordinates="RpZ", + params=[], + transforms={}, + profiles=[], + data=_kernel_BS_plus_grad_S.keys, + resolution_requirement="tz", + grid_requirement={"can_fft2": True}, + parameterization="desc.magnetic_fields._laplace.SourceFreeField", + options=Options.__doc__, + eval_interpolator="""_BIESTInterpolator : + Interpolator from source grid to evaluation grid on boundary. + If not given, default is to interpolate to source grid. + """, + on_boundary="bool : Whether RpZcoords are on boundary surface.", + public=False, +) +def _grad_potential(params, transforms, profiles, data, RpZ_data, **kwargs): + # noqa: unused dependency + options = kwargs.get("options", Options()) + sign = 1 - 2 * int("exterior" in options.problem) + + if kwargs["on_boundary"]: + RpZ_data["βˆ‡Ο†"] = ( + sign + * 2 + * singular_integral( + RpZ_data, + data, + kwargs.get("eval_interpolator", data.get("interpolator", None)), + _kernel_BS_plus_grad_S, + chunk_size=options.chunk_size, + ) + ) + else: + grid = transforms["grid"] + eval_data, source_data = _prune_data( + RpZ_data, None, data, grid, _kernel_BS_plus_grad_S + ) + RpZ_data["βˆ‡Ο†"] = sign * _nonsingular_part( + eval_data, + None, + source_data, + grid, + st=jnp.nan, + sz=jnp.nan, + kernel=_kernel_BS_plus_grad_S, + chunk_size=options.chunk_size, + ) + return RpZ_data + + +@register_compute_fun( + name="B", + label="B", + units="T", + units_long="Tesla", + description="Magnetic field", + dim=3, + coordinates="RpZ", + params=[], + transforms={}, + profiles=[], + data=["βˆ‡Ο†", "B0"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", +) +def _total_B(params, transforms, profiles, data, RpZ_data, **kwargs): + RpZ_data["B"] = RpZ_data["βˆ‡Ο†"] + RpZ_data["B0"] + return RpZ_data + + +@register_compute_fun( + name="B0*n", + label="B_0 \\cdot n_{\\rho}", + units="T", + units_long="Tesla", + description="Auxillary field dotted into flux surface normal", + dim=1, + coordinates="tz", + params=[], + transforms={"grid": []}, + profiles=[], + data=["x", "n_rho"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", + B0="_MagneticField : Field object to compute with.", + field_grid="Grid : Source grid used to compute magnetic field.", + options=Options.__doc__, +) +def _B0_dot_n(params, transforms, profiles, data, **kwargs): + options = kwargs.get("options", Options()) + data["B0*n"] = dot( + kwargs["B0"].compute_magnetic_field( + coords=data["x"], + source_grid=kwargs.get("field_grid", None), + chunk_size=options.chunk_size, + ), + data["n_rho"], + ) + return data + + +@register_compute_fun( + name="B0", + label="B0", + units="T", + units_long="Tesla", + description="Auxillary field", + dim=3, + coordinates="RpZ", + params=[], + transforms={"grid": []}, + profiles=[], + data=[], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", + B0="_MagneticField : Field object to compute with.", + field_grid="Grid : Source grid used to compute magnetic field.", + options=Options.__doc__, + public=False, +) +def _B0_field(params, transforms, profiles, data, RpZ_data, **kwargs): + options = kwargs.get("options", Options()) + coords = jnp.column_stack([RpZ_data["R"], RpZ_data["phi"], RpZ_data["Z"]]) + RpZ_data["B0"] = kwargs["B0"].compute_magnetic_field( + coords=coords, + source_grid=kwargs.get("field_grid", None), + chunk_size=options.chunk_size, + ) + return RpZ_data + + +@register_compute_fun( + name="B_coil", + label="B_{\\text{coil}}", + units="T", + units_long="Tesla", + description="Magnetic field due to coils", + dim=3, + coordinates="rtz", + params=[], + transforms={"grid": []}, + profiles=[], + data=["x"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", + options=Options.__doc__, + B_coil="_MagneticField : Field object to compute with.", + field_grid="Grid : Source grid used to compute magnetic field.", +) +def _B_coil_field(params, transforms, profiles, data, **kwargs): + options = kwargs.get("options", Options()) + data["B_coil"] = kwargs["B_coil"].compute_magnetic_field( + coords=data["x"], + source_grid=kwargs.get("field_grid", None), + chunk_size=options.B_coil_chunk_size, + ) + return data + + +@register_compute_fun( + name="n_rho x B_coil", + label="n_{\\rho} \\times B_{\\text{coil}}", + units="T", + units_long="Tesla", + description="Flux surface normal cross magnetic field due to coils", + dim=1, + coordinates="rtz", + params=[], + transforms={}, + profiles=[], + data=["n_rho", "B_coil"], + parameterization="desc.magnetic_fields._laplace.SourceFreeField", +) +def _n_rho_x_B_coil(params, transforms, profiles, data, **kwargs): + data["n_rho x B_coil"] = cross(data["n_rho"], data["B_coil"]) + return data + + +@register_compute_fun( + name="Y_coil", + label="Y_{\\text{coil}}", + units="T m", + units_long="Tesla meter", + description="Net poloidal current produced by magnetic coils", + dim=0, + coordinates="", + params=["Y"], + transforms={}, + profiles=[], + data=["e_zeta", "B_coil"], + options=Options.__doc__, + parameterization="desc.magnetic_fields._laplace.FreeSurfaceOuterField", +) +def _Y_coil(params, transforms, profiles, data, **kwargs): + if params.get("Y", None) is not None: + data["Y_coil"] = params["Y"] + return data + # Equation B.2 in [1]_. + data["Y_coil"] = dot(data["B_coil"], data["e_zeta"]).mean() + return data + + +@register_compute_fun( + name="Phi_coil_mn", + label="\\Phi_{\\text{coil}, mn}", + units="T m", + units_long="Tesla meter", + description="Fourier coefficients of periodic part of coil scalar potential", + dim=1, + coordinates="tz", + params=[], + transforms={"Phi_coil": [[0, 0, 0]]}, + profiles=[], + data=["n_rho x B_coil", "n_rho x grad(theta)", "n_rho x grad(zeta)", "Y_coil"], + parameterization="desc.magnetic_fields._laplace.FreeSurfaceOuterField", +) +def _Phi_mn_coil(params, transforms, profiles, data, **kwargs): + """Returns coil potential harmonics. + + ``B_coil`` must be smooth and divergence free for correctness of inversion. + TODO: Compute this from scalar potential integral, without inversion. + """ + grid = transforms["grid"] + assert grid.num_rho == 1 + + basis = transforms["Phi_coil"].basis + # TODO: could compute these in objective build + # and avoid computing if they are passed in as kwargs + _t = basis.evaluate(grid, [0, 1, 0])[:, None] + _z = basis.evaluate(grid, [0, 0, 1])[:, None] + + mat = ( + _t * data["n_rho x grad(theta)"][..., None] + + _z * data["n_rho x grad(zeta)"][..., None] + ).reshape(grid.num_nodes * 3, basis.num_modes) + if basis.gauge_idx.size: + mat = jnp.delete(mat, basis.gauge_idx, axis=1, assume_unique_indices=True) + mat = lx.MatrixLinearOperator(mat) + + # Equation 5.16 in [1]_. + Phi_coil_mn = lx.linear_solve( + mat, + (data["n_rho x B_coil"] - data["Y_coil"] * data["n_rho x grad(zeta)"]).ravel(), + solver=lx.AutoLinearSolver(well_posed=None), + ).value + if basis.gauge_idx.size: + Phi_coil_mn = jnp.insert(Phi_coil_mn, basis.gauge_idx, 0.0) + + data["Phi_coil_mn"] = Phi_coil_mn + return data + + +@register_compute_fun( + name="Phi_coil (periodic)", + label="(n \\times \\nabla)^{-1} (n \\times B_{\\text{coil}})", + units="T m", + units_long="Tesla meter", + description="Periodic part of magnetic scalar potential of coil field", + dim=1, + coordinates="tz", + params=[], + transforms={"Phi_coil": [[0, 0, 0]]}, + profiles=[], + data=["Phi_coil_mn"], + parameterization="desc.magnetic_fields._laplace.FreeSurfaceOuterField", +) +def _Phi_coil_periodic(params, transforms, profiles, data, **kwargs): + data["Phi_coil (periodic)"] = transforms["Phi_coil"].transform(data["Phi_coil_mn"]) + return data + + +@register_compute_fun( + name="Phi_coil (secular)", + label="(n \\times \\nabla)^{-1} (n \\times B_{\\text{coil}})", + units="T m", + units_long="Tesla meter", + description="Secular part of magnetic scalar potential of coil field", + dim=1, + coordinates="z", + params=[], + transforms={}, + profiles=[], + data=["zeta", "Y_coil"], + parameterization="desc.magnetic_fields._laplace.FreeSurfaceOuterField", +) +def _Phi_coil_secular(params, transforms, profiles, data, **kwargs): + data["Phi_coil (secular)"] = data["Y_coil"] * data["zeta"] + return data + + +@register_compute_fun( + name="Phi_coil", + label="(n \\times \\nabla)^{-1} (n \\times B_{\\text{coil}})", + units="T m", + units_long="Tesla meter", + description="Magnetic scalar potential of coil field", + dim=1, + coordinates="tz", + params=[], + transforms={}, + profiles=[], + data=["Phi_coil (periodic)", "Phi_coil (secular)"], + parameterization="desc.magnetic_fields._laplace.FreeSurfaceOuterField", +) +def _Phi_coil(params, transforms, profiles, data, **kwargs): + data["Phi_coil"] = data["Phi_coil (periodic)"] + data["Phi_coil (secular)"] + return data + + +@register_compute_fun( + name="Phi_mn", + label="\\Phi_{m n}", + units="T m", + units_long="Tesla meter", + description="Fourier coefficients of periodic part of potential", + dim=1, + coordinates="tz", + params=[], + transforms={"Phi": [[0, 0, 0]]}, + profiles=[], + data=list(set(_kernel_dipole_plus_half.keys) - {"Phi (periodic)"}) + + ["Phi_coil (periodic)", "S[B0*n]", "interpolator"], + resolution_requirement="tz", + grid_requirement={"can_fft2": True}, + parameterization="desc.magnetic_fields._laplace.FreeSurfaceOuterField", + options=Options.__doc__, +) +def _scalar_potential_mn_free_surface(params, transforms, profiles, data, **kwargs): + # noqa: unused dependency + options = Options.select_solver( + kwargs.get("options", Options())._replace(problem="interior Dirichlet") + ) + + boundary_condition = data["S[B0*n]"] - data["Phi_coil (periodic)"] + if options.solve_method == "direct": + data["Phi_mn"] = _direct_solve( + boundary_condition, + data.get("potential data", data), + data, + data["interpolator"], + transforms["Phi"].basis, + options, + ) + else: + data["Phi (periodic)"] = _iterative_solve( + boundary_condition, + data.get("potential data", data), + data, + data["interpolator"], + options, + ) + if options.full_output: + data["Phi (periodic)"], (data["Phi error"], data["num_steps"]) = data[ + "Phi (periodic)" + ] + + assert data["Phi (periodic)"].size == transforms["Phi"].grid.num_nodes + data["Phi_mn"] = transforms["Phi"].fit(data["Phi (periodic)"]) + return data + + +@register_compute_fun( + name="Ξ³ potential", + label="\\gamma", + units="T m", + units_long="Tesla meter", + description="Double layer potential with dipole density -Ξ¦", + dim=1, + coordinates="tz", + params=[], + transforms={}, + profiles=[], + data=_kernel_dipole_plus_half.keys + ["interpolator"], + resolution_requirement="tz", + grid_requirement={"can_fft2": True}, + parameterization="desc.magnetic_fields._laplace.FreeSurfaceOuterField", + options=Options.__doc__, + public=False, +) +def _gamma_potential(params, transforms, profiles, data, **kwargs): + # noqa: unused dependency + options = kwargs.get("options", Options()) + data["Phi(x) (periodic)"] = data["Phi (periodic)"] + # Left hand side of equation 5.15 in [1]_ computed by evaluating + # the right hand side. This is used for testing. + data["Ξ³ potential"] = data["Phi (periodic)"] - _D_plus_half( + data, + data, + data["interpolator"], + chunk_size=options.chunk_size, + ) + return data diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index a9ff9685a0..e51a3a3260 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -29,10 +29,10 @@ transforms={}, profiles=[], coordinates="rtz", - data=["e_rho", "e_theta", "e_zeta"], + data=["e_rho", "e_theta x e_zeta"], ) def _sqrtg(params, transforms, profiles, data, **kwargs): - data["sqrt(g)"] = dot(data["e_rho"], cross(data["e_theta"], data["e_zeta"])) + data["sqrt(g)"] = dot(data["e_rho"], data["e_theta x e_zeta"]) return data @@ -92,16 +92,14 @@ def _sqrtg_clebsch(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["e_theta", "e_zeta"], + data=["e_theta x e_zeta"], parameterization=[ "desc.equilibrium.equilibrium.Equilibrium", "desc.geometry.surface.FourierRZToroidalSurface", ], ) -def _e_theta_x_e_zeta(params, transforms, profiles, data, **kwargs): - data["|e_theta x e_zeta|"] = safenorm( - cross(data["e_theta"], data["e_zeta"]), axis=-1 - ) +def _e_theta_x_e_zeta_norm(params, transforms, profiles, data, **kwargs): + data["|e_theta x e_zeta|"] = safenorm(data["e_theta x e_zeta"], axis=-1) return data @@ -117,13 +115,13 @@ def _e_theta_x_e_zeta(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["e_theta", "e_zeta", "e_theta_r", "e_zeta_r"], + data=["e_theta", "e_zeta", "e_theta x e_zeta", "e_theta_r", "e_zeta_r"], parameterization=[ "desc.equilibrium.equilibrium.Equilibrium", ], ) def _e_theta_x_e_zeta_r(params, transforms, profiles, data, **kwargs): - a = cross(data["e_theta"], data["e_zeta"]) + a = data["e_theta x e_zeta"] a_r = cross(data["e_theta_r"], data["e_zeta"]) + cross( data["e_theta"], data["e_zeta_r"] ) @@ -150,13 +148,21 @@ def _e_theta_x_e_zeta_r(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["e_theta", "e_zeta", "e_theta_r", "e_zeta_r", "e_theta_rr", "e_zeta_rr"], + data=[ + "e_theta", + "e_zeta", + "e_theta_r", + "e_zeta_r", + "e_theta_rr", + "e_zeta_rr", + "e_theta x e_zeta", + ], parameterization=[ "desc.equilibrium.equilibrium.Equilibrium", ], ) def _e_theta_x_e_zeta_rr(params, transforms, profiles, data, **kwargs): - a = cross(data["e_theta"], data["e_zeta"]) + a = data["e_theta x e_zeta"] a_r = cross(data["e_theta_r"], data["e_zeta"]) + cross( data["e_theta"], data["e_zeta_r"] ) @@ -188,7 +194,14 @@ def _e_theta_x_e_zeta_rr(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["e_theta", "e_theta_z", "e_zeta", "e_zeta_z", "|e_theta x e_zeta|"], + data=[ + "e_theta", + "e_theta_z", + "e_zeta", + "e_zeta_z", + "|e_theta x e_zeta|", + "e_theta x e_zeta", + ], parameterization=[ "desc.equilibrium.equilibrium.Equilibrium", "desc.geometry.surface.FourierRZToroidalSurface", @@ -200,7 +213,7 @@ def _e_theta_x_e_zeta_z(params, transforms, profiles, data, **kwargs): cross(data["e_theta_z"], data["e_zeta"]) + cross(data["e_theta"], data["e_zeta_z"]) ), - cross(data["e_theta"], data["e_zeta"]), + data["e_theta x e_zeta"], ) / (data["|e_theta x e_zeta|"]) return data @@ -216,13 +229,13 @@ def _e_theta_x_e_zeta_z(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["e^theta*sqrt(g)"], + data=["e_zeta x e_rho"], parameterization=[ "desc.equilibrium.equilibrium.Equilibrium", ], ) -def _e_zeta_x_e_rho(params, transforms, profiles, data, **kwargs): - data["|e_zeta x e_rho|"] = safenorm(data["e^theta*sqrt(g)"], axis=-1) +def _e_zeta_x_e_rho_norm(params, transforms, profiles, data, **kwargs): + data["|e_zeta x e_rho|"] = safenorm(data["e_zeta x e_rho"], axis=-1) return data @@ -354,11 +367,19 @@ def _e_rho_x_e_theta_rr(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["e_rho", "e_theta", "e_zeta", "e_rho_r", "e_theta_r", "e_zeta_r"], + data=[ + "e_rho", + "e_theta", + "e_zeta", + "e_rho_r", + "e_theta_r", + "e_zeta_r", + "e_theta x e_zeta", + ], ) def _sqrtg_r(params, transforms, profiles, data, **kwargs): data["sqrt(g)_r"] = ( - dot(data["e_rho_r"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_r"], data["e_theta x e_zeta"]) + dot(data["e_rho"], cross(data["e_theta_r"], data["e_zeta"])) + dot(data["e_rho"], cross(data["e_theta"], data["e_zeta_r"])) ) @@ -377,11 +398,19 @@ def _sqrtg_r(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["e_rho", "e_theta", "e_zeta", "e_rho_t", "e_theta_t", "e_zeta_t"], + data=[ + "e_rho", + "e_theta", + "e_zeta", + "e_rho_t", + "e_theta_t", + "e_zeta_t", + "e_theta x e_zeta", + ], ) def _sqrtg_t(params, transforms, profiles, data, **kwargs): data["sqrt(g)_t"] = ( - dot(data["e_rho_t"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_t"], data["e_theta x e_zeta"]) + dot(data["e_rho"], cross(data["e_theta_t"], data["e_zeta"])) + dot(data["e_rho"], cross(data["e_theta"], data["e_zeta_t"])) ) @@ -400,11 +429,19 @@ def _sqrtg_t(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["e_rho", "e_theta", "e_zeta", "e_rho_z", "e_theta_z", "e_zeta_z"], + data=[ + "e_rho", + "e_theta", + "e_zeta", + "e_rho_z", + "e_theta_z", + "e_zeta_z", + "e_theta x e_zeta", + ], ) def _sqrtg_z(params, transforms, profiles, data, **kwargs): data["sqrt(g)_z"] = ( - dot(data["e_rho_z"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_z"], data["e_theta x e_zeta"]) + dot(data["e_rho"], cross(data["e_theta_z"], data["e_zeta"])) + dot(data["e_rho"], cross(data["e_theta"], data["e_zeta_z"])) ) @@ -433,11 +470,12 @@ def _sqrtg_z(params, transforms, profiles, data, **kwargs): "e_rho_rr", "e_theta_rr", "e_zeta_rr", + "e_theta x e_zeta", ], ) def _sqrtg_rr(params, transforms, profiles, data, **kwargs): data["sqrt(g)_rr"] = ( - dot(data["e_rho_rr"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_rr"], data["e_theta x e_zeta"]) + dot(data["e_rho"], cross(data["e_theta_rr"], data["e_zeta"])) + dot(data["e_rho"], cross(data["e_theta"], data["e_zeta_rr"])) + 2 * dot(data["e_rho_r"], cross(data["e_theta_r"], data["e_zeta"])) @@ -472,11 +510,12 @@ def _sqrtg_rr(params, transforms, profiles, data, **kwargs): "e_rho_rrr", "e_theta_rrr", "e_zeta_rrr", + "e_theta x e_zeta", ], ) def _sqrtg_rrr(params, transforms, profiles, data, **kwargs): data["sqrt(g)_rrr"] = ( - dot(data["e_rho_rrr"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_rrr"], data["e_theta x e_zeta"]) + dot(data["e_rho"], cross(data["e_theta_rrr"], data["e_zeta"])) + dot(data["e_rho"], cross(data["e_theta"], data["e_zeta_rrr"])) + 3 * dot(data["e_rho_rr"], cross(data["e_theta_r"], data["e_zeta"])) @@ -521,11 +560,12 @@ def _sqrtg_rrr(params, transforms, profiles, data, **kwargs): "e_rho_rrt", "e_theta_rrt", "e_zeta_rrt", + "e_theta x e_zeta", ], ) def _sqrtg_rrt(params, transforms, profiles, data, **kwargs): data["sqrt(g)_rrt"] = ( - dot(data["e_rho_rrt"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_rrt"], data["e_theta x e_zeta"]) + dot( data["e_rho_rr"], cross(data["e_theta_t"], data["e_zeta"]) @@ -585,11 +625,12 @@ def _sqrtg_rrt(params, transforms, profiles, data, **kwargs): "e_rho_tt", "e_theta_tt", "e_zeta_tt", + "e_theta x e_zeta", ], ) def _sqrtg_tt(params, transforms, profiles, data, **kwargs): data["sqrt(g)_tt"] = ( - dot(data["e_rho_tt"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_tt"], data["e_theta x e_zeta"]) + dot(data["e_rho"], cross(data["e_theta_tt"], data["e_zeta"])) + dot(data["e_rho"], cross(data["e_theta"], data["e_zeta_tt"])) + 2 * dot(data["e_rho_t"], cross(data["e_theta_t"], data["e_zeta"])) @@ -630,11 +671,12 @@ def _sqrtg_tt(params, transforms, profiles, data, **kwargs): "e_rho_rtt", "e_theta_rtt", "e_zeta_rtt", + "e_theta x e_zeta", ], ) def _sqrtg_rtt(params, transforms, profiles, data, **kwargs): data["sqrt(g)_rtt"] = ( - dot(data["e_rho_rtt"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_rtt"], data["e_theta x e_zeta"]) + dot(data["e_rho_r"], cross(data["e_theta_tt"], data["e_zeta"])) + dot(data["e_rho_r"], cross(data["e_theta"], data["e_zeta_tt"])) + 2 * dot(data["e_rho_rt"], cross(data["e_theta_t"], data["e_zeta"])) @@ -677,11 +719,12 @@ def _sqrtg_rtt(params, transforms, profiles, data, **kwargs): "e_rho_zz", "e_theta_zz", "e_zeta_zz", + "e_theta x e_zeta", ], ) def _sqrtg_zz(params, transforms, profiles, data, **kwargs): data["sqrt(g)_zz"] = ( - dot(data["e_rho_zz"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_zz"], data["e_theta x e_zeta"]) + dot(data["e_rho"], cross(data["e_theta_zz"], data["e_zeta"])) + dot(data["e_rho"], cross(data["e_theta"], data["e_zeta_zz"])) + 2 * dot(data["e_rho_z"], cross(data["e_theta_z"], data["e_zeta"])) @@ -722,11 +765,12 @@ def _sqrtg_zz(params, transforms, profiles, data, **kwargs): "e_rho_rzz", "e_theta_rzz", "e_zeta_rzz", + "e_theta x e_zeta", ], ) def _sqrtg_rzz(params, transforms, profiles, data, **kwargs): data["sqrt(g)_rzz"] = ( - dot(data["e_rho_rzz"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_rzz"], data["e_theta x e_zeta"]) + dot(data["e_rho_r"], cross(data["e_theta_zz"], data["e_zeta"])) + dot(data["e_rho_r"], cross(data["e_theta"], data["e_zeta_zz"])) + 2 * dot(data["e_rho_rz"], cross(data["e_theta_z"], data["e_zeta"])) @@ -772,11 +816,12 @@ def _sqrtg_rzz(params, transforms, profiles, data, **kwargs): "e_rho_rt", "e_theta_rt", "e_zeta_rt", + "e_theta x e_zeta", ], ) def _sqrtg_rt(params, transforms, profiles, data, **kwargs): data["sqrt(g)_rt"] = ( - dot(data["e_rho_rt"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_rt"], data["e_theta x e_zeta"]) + dot(data["e_rho_r"], cross(data["e_theta_t"], data["e_zeta"])) + dot(data["e_rho_r"], cross(data["e_theta"], data["e_zeta_t"])) + dot(data["e_rho"], cross(data["e_theta_rt"], data["e_zeta"])) @@ -813,11 +858,12 @@ def _sqrtg_rt(params, transforms, profiles, data, **kwargs): "e_rho_tz", "e_theta_tz", "e_zeta_tz", + "e_theta x e_zeta", ], ) def _sqrtg_tz(params, transforms, profiles, data, **kwargs): data["sqrt(g)_tz"] = ( - dot(data["e_rho_tz"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_tz"], data["e_theta x e_zeta"]) + dot(data["e_rho_z"], cross(data["e_theta_t"], data["e_zeta"])) + dot(data["e_rho_z"], cross(data["e_theta"], data["e_zeta_t"])) + dot(data["e_rho_t"], cross(data["e_theta_z"], data["e_zeta"])) @@ -866,11 +912,12 @@ def _sqrtg_tz(params, transforms, profiles, data, **kwargs): "e_rho_rtz", "e_theta_rtz", "e_zeta_rtz", + "e_theta x e_zeta", ], ) def _sqrtg_rtz(params, transforms, profiles, data, **kwargs): data["sqrt(g)_rtz"] = ( - dot(data["e_rho_rtz"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_rtz"], data["e_theta x e_zeta"]) + dot( data["e_rho_rz"], cross(data["e_theta_t"], data["e_zeta"]) @@ -944,11 +991,12 @@ def _sqrtg_rtz(params, transforms, profiles, data, **kwargs): "e_rho_rz", "e_theta_rz", "e_zeta_rz", + "e_theta x e_zeta", ], ) def _sqrtg_rz(params, transforms, profiles, data, **kwargs): data["sqrt(g)_rz"] = ( - dot(data["e_rho_rz"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_rz"], data["e_theta x e_zeta"]) + dot(data["e_rho_r"], cross(data["e_theta_z"], data["e_zeta"])) + dot(data["e_rho_r"], cross(data["e_theta"], data["e_zeta_z"])) + dot(data["e_rho_z"], cross(data["e_theta_r"], data["e_zeta"])) @@ -991,11 +1039,12 @@ def _sqrtg_rz(params, transforms, profiles, data, **kwargs): "e_rho_rrz", "e_theta_rrz", "e_zeta_rrz", + "e_theta x e_zeta", ], ) def _sqrtg_rrz(params, transforms, profiles, data, **kwargs): data["sqrt(g)_rrz"] = ( - dot(data["e_rho_rrz"], cross(data["e_theta"], data["e_zeta"])) + dot(data["e_rho_rrz"], data["e_theta x e_zeta"]) + dot( data["e_rho_rr"], cross(data["e_theta_z"], data["e_zeta"]) @@ -1791,6 +1840,7 @@ def _g_sup_zz_z(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["g^rr"], + aliases=["|e^rho|"], ) def _gradrho(params, transforms, profiles, data, **kwargs): data["|grad(rho)|"] = jnp.sqrt(data["g^rr"]) @@ -1871,6 +1921,7 @@ def _gradpsi_mag2(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["g^tt"], + aliases=["|e^theta|"], ) def _gradtheta(params, transforms, profiles, data, **kwargs): data["|grad(theta)|"] = jnp.sqrt(data["g^tt"]) @@ -1889,6 +1940,7 @@ def _gradtheta(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["g^zz"], + aliases=["|e^zeta|"], ) def _gradzeta(params, transforms, profiles, data, **kwargs): data["|grad(zeta)|"] = jnp.sqrt(data["g^zz"]) diff --git a/desc/compute/_surface.py b/desc/compute/_surface.py index d63622f246..ec855b9fb7 100644 --- a/desc/compute/_surface.py +++ b/desc/compute/_surface.py @@ -163,7 +163,7 @@ def _Phi_z_CurrentPotentialField(params, transforms, profiles, data, **kwargs): ], ) def _K_sup_theta_CurrentPotentialField(params, transforms, profiles, data, **kwargs): - data["K^theta"] = -data["Phi_z"] * (1 / data["|e_theta x e_zeta|"]) + data["K^theta"] = -data["Phi_z"] / data["|e_theta x e_zeta|"] return data @@ -185,7 +185,7 @@ def _K_sup_theta_CurrentPotentialField(params, transforms, profiles, data, **kwa ], ) def _K_sup_zeta_CurrentPotentialField(params, transforms, profiles, data, **kwargs): - data["K^zeta"] = data["Phi_t"] * (1 / data["|e_theta x e_zeta|"]) + data["K^zeta"] = data["Phi_t"] / data["|e_theta x e_zeta|"] return data @@ -208,7 +208,8 @@ def _K_sup_zeta_CurrentPotentialField(params, transforms, profiles, data, **kwar ], ) def _K_CurrentPotentialField(params, transforms, profiles, data, **kwargs): - data["K"] = (data["K^zeta"] * data["e_zeta"].T).T + ( - data["K^theta"] * data["e_theta"].T - ).T + data["K"] = ( + data["K^zeta"][:, jnp.newaxis] * data["e_zeta"] + + data["K^theta"][:, jnp.newaxis] * data["e_theta"] + ) return data diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index 747fcd6315..88bd5526d2 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -48,24 +48,24 @@ def assign_alias_data( def register_compute_fun( # noqa: C901 + *, name, + aliases=None, label, units, units_long, description, dim, + coordinates, params, transforms, profiles, - coordinates, data, axis_limit_data=None, - aliases=None, - parameterization="desc.equilibrium.equilibrium.Equilibrium", resolution_requirement="", grid_requirement=None, source_grid_requirement=None, - *, + parameterization="desc.equilibrium.equilibrium.Equilibrium", public=True, **kwargs, ): @@ -76,6 +76,9 @@ def register_compute_fun( # noqa: C901 name : str Name of the quantity. This will be used as the key used to compute the quantity in `compute` and its name in the data dictionary. + aliases : list of str + Aliases of `name`. Will be stored in the data dictionary as a copy of `name`s + data. label : str Title of the quantity in LaTeX format. units : str @@ -87,25 +90,19 @@ def register_compute_fun( # noqa: C901 dim : int Dimension of the quantity: 0-D (global qty), 1-D (local scalar qty), or 3-D (local vector qty). + coordinates : str + Coordinate dependency. IE, "rtz" for a function of rho, theta, zeta, or "r" for + a flux function, etc. params : list of str Parameters of equilibrium needed to compute quantity, eg "R_lmn", "Z_lmn" transforms : dict Dictionary of keys and derivative orders [rho, theta, zeta] for R, Z, etc. profiles : list of str Names of profiles needed, eg "iota", "pressure" - coordinates : str - Coordinate dependency. IE, "rtz" for a function of rho, theta, zeta, or "r" for - a flux function, etc. data : list of str Names of other items in the data index needed to compute qty. axis_limit_data : list of str Names of other items in the data index needed to compute axis limit of qty. - aliases : list of str - Aliases of `name`. Will be stored in the data dictionary as a copy of `name`s - data. - parameterization : str or list of str - Name of desc types the method is valid for. eg `'desc.geometry.FourierXYZCurve'` - or `'desc.equilibrium.Equilibrium'`. resolution_requirement : str Resolution requirements in coordinates. I.e. "r" expects radial resolution in the grid. Likewise, "rtz" is shorthand for "rho, theta, zeta" and indicates @@ -128,6 +125,9 @@ def register_compute_fun( # noqa: C901 which will allow accessing the Clebsch-Type rho, alpha, zeta coordinates in ``transforms["grid"].source_grid``` that correspond to the DESC rho, theta, zeta coordinates in ``transforms["grid"]``. + parameterization : str or list of str + Name of desc types the method is valid for. eg `'desc.geometry.FourierXYZCurve'` + or `'desc.equilibrium.Equilibrium'`. public : bool Whether to include this quantity in the public documentation. Default is true. @@ -286,6 +286,15 @@ def _decorator(func): "desc.geometry.core.Curve", ], "desc.magnetic_fields._core.OmnigenousField": [], + "desc.magnetic_fields._laplace.SourceFreeField": [ + "desc.geometry.surface.FourierRZToroidalSurface", + "desc.geometry.core.Surface", + ], + "desc.magnetic_fields._laplace.FreeSurfaceOuterField": [ + "desc.magnetic_fields._laplace.SourceFreeField", + "desc.geometry.surface.FourierRZToroidalSurface", + "desc.geometry.core.Surface", + ], } data_index = {p: {} for p in _class_inheritance.keys()} all_kwargs = {p: {} for p in _class_inheritance.keys()} diff --git a/desc/compute/utils.py b/desc/compute/utils.py index bd07b54122..9f93643f34 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -1,4 +1,4 @@ -"""Functions for flux surface averages and vector algebra operations.""" +"""Utilities for computing dependencies and transforms.""" import copy import inspect @@ -9,7 +9,7 @@ from desc.backend import execute_on_cpu, jnp from desc.grid import Grid -from ..utils import errorif, rpz2xyz, rpz2xyz_vec +from ..utils import errorif, rpz2xyz, rpz2xyz_vec, setdefault, warnif from .data_index import allowed_kwargs, data_index, deprecated_names # map from profile name to equilibrium parameter name @@ -36,7 +36,14 @@ def _parse_parameterization(p): def compute( # noqa: C901 - parameterization, names, params, transforms, profiles, data=None, **kwargs + parameterization, + names, + params, + transforms, + profiles, + data=None, + RpZ_data=None, + **kwargs, ): """Compute the quantity given by name on grid. @@ -50,7 +57,7 @@ def compute( # noqa: C901 Parameters from the equilibrium, such as R_lmn, Z_lmn, i_l, p_l, etc. Defaults to attributes of self. transforms : dict of Transform - Transforms for R, Z, lambda, etc. Default is to build from grid + Transforms for R, Z, lambda, etc. Default is to build from grid. profiles : dict of Profile Profile objects for pressure, iota, current, etc. Defaults to attributes of self @@ -59,13 +66,24 @@ def compute( # noqa: C901 Any vector v = vΒΉ RΜ‚ + vΒ² Ο•Μ‚ + vΒ³ ZΜ‚ should be given in components v = [vΒΉ, vΒ², vΒ³] where RΜ‚, Ο•Μ‚, ZΜ‚ are the normalized basis vectors of the cylindrical coordinates R, Ο•, Z. + RpZ_data : dict[str, jnp.ndarray] + Data evaluated so far on the (R, Ο•, Z) coordinates in this dictionary. + Should store the three entries ``"R"``, ``"phi"``, and ``"Z"`` + if the intention is to compute something at these coordinates. Returns ------- - data : dict of ndarray - Computed quantity and intermediate variables. + data : dict[str, jnp.ndarray] + Quantities and intermediate variables computed on the + grid attached to the transforms. + RpZ_data : dict[str, jnp.ndarray] + Quantities and intermediate variables computed on the + (R, Ο•, Z) coordinates in ``RpZ_data``. If ``RpZ_data`` + was not given then this dictionary will not be returned. """ + return_RpZ_data = RpZ_data is not None + basis = kwargs.pop("basis", "rpz").lower() errorif(basis not in {"rpz", "xyz"}, NotImplementedError) p = _parse_parameterization(parameterization) @@ -77,20 +95,18 @@ def compute( # noqa: C901 with warnings.catch_warnings(): warnings.simplefilter("always", DeprecationWarning) for name in names: - if name not in data_index[p]: - raise ValueError( - f"Unrecognized value '{name}' for parameterization {p}." - ) - if name in list(deprecated_names.keys()): - warnings.warn( - f"Variable name {name} is deprecated and will be removed in a " - f"future DESC version, use name {deprecated_names[name]} " - "instead.", - DeprecationWarning, - ) + errorif( + name not in data_index[p], + msg=f"Unrecognized value '{name}' for parameterization {p}.", + ) + warnif( + name in list(deprecated_names.keys()), + DeprecationWarning, + f"Variable name {name} is deprecated and will be removed in a future " + f"DESC version, use name {deprecated_names.get(name, None)} instead.", + ) bad_kwargs = kwargs.keys() - allowed_kwargs - if len(bad_kwargs) > 0: - raise ValueError(f"Unrecognized argument(s): {bad_kwargs}") + errorif(bad_kwargs, msg=f"Unrecognized argument(s): {bad_kwargs}") for name in names: assert _has_params(name, params, p), f"Don't have params to compute {name}" @@ -135,6 +151,9 @@ def check_fun(name): if data is None: data = {} + # Need to query this before JIT barriers detach them from each other. + same_grid = data is RpZ_data + data = _compute( p, names, @@ -142,13 +161,42 @@ def check_fun(name): transforms=transforms, profiles=profiles, data=data, + RpZ_data=RpZ_data, **kwargs, ) + if return_RpZ_data: + if same_grid: + RpZ_data = data + RpZ_data = _compute_RpZ_data( + p, + names, + params=params, + transforms=transforms, + profiles=profiles, + data=data, + RpZ_data=RpZ_data, + **kwargs, + ) + if same_grid: + data = RpZ_data + + data = _convert_basis(p, data, basis) + if return_RpZ_data: + if data is not RpZ_data: + RpZ_data = _convert_basis(p, RpZ_data, basis) + return data, RpZ_data + else: + return data + + +def _convert_basis(p, data, basis): # convert data from default 'rpz' basis to 'xyz' basis, if requested by the user if basis == "xyz": for name in data.keys(): + if name == "potential data": + continue errorif( data_index[p][name]["dim"] == (3, 3), NotImplementedError, @@ -159,7 +207,6 @@ def check_fun(name): data[name] = rpz2xyz(data[name]) else: data[name] = rpz2xyz_vec(data[name], phi=data["phi"]) - return data @@ -177,6 +224,13 @@ def _compute( using recursion to compute dependencies. If you want to call this function, you cannot give the argument basis='xyz' since that will break the recursion. In that case, either call above function or manually convert the output to xyz basis. + + Returns + ------- + data : dict[str, jnp.ndarray] + Quantities and intermediate variables computed on the + grid attached to the transforms. + """ assert kwargs.get("basis", "rpz") == "rpz", "_compute only works in rpz coordinates" parameterization = _parse_parameterization(parameterization) @@ -189,6 +243,7 @@ def _compute( if name in data: # don't compute something that's already been computed continue + if not has_data_dependencies( parameterization, name, data, transforms["grid"].axis.size ): @@ -215,12 +270,82 @@ def _compute( **kwargs, ) # now compute the quantity - data = data_index[parameterization][name]["fun"]( - params=params, transforms=transforms, profiles=profiles, data=data, **kwargs - ) + if data_index[parameterization][name]["coordinates"] != "RpZ": + data = data_index[parameterization][name]["fun"]( + params=params, + transforms=transforms, + profiles=profiles, + data=data, + **kwargs, + ) return data +def _compute_RpZ_data( + parameterization, + names, + params, + transforms, + profiles, + data, + RpZ_data, + **kwargs, +): + """Same as above but without checking inputs for faster recursion. + + Any vector v = vΒΉ RΜ‚ + vΒ² Ο•Μ‚ + vΒ³ ZΜ‚ should be given in components + v = [vΒΉ, vΒ², vΒ³] where RΜ‚, Ο•Μ‚, ZΜ‚ are the normalized basis vectors + of the cylindrical coordinates R, Ο•, Z. + + We need to directly call this function in objectives, since the checks in above + function are not compatible with JIT. This function computes given names while + using recursion to compute dependencies. If you want to call this function, you + cannot give the argument basis='xyz' since that will break the recursion. In that + case, either call above function or manually convert the output to xyz basis. + + Returns + ------- + RpZ_data : dict[str, jnp.ndarray] + Quantities and intermediate variables computed on the + (R, Ο•, Z) coordinates in ``RpZ_data``. + + """ + assert kwargs.get("basis", "rpz") == "rpz", "_compute only works in rpz coordinates" + parameterization = _parse_parameterization(parameterization) + if isinstance(names, str): + names = [names] + + for name in names: + if ( + data_index[parameterization][name]["coordinates"] != "RpZ" + or name in RpZ_data + ): + continue + + if not has_RpZ_data_dependencies(parameterization, name, data, RpZ_data): + # then compute the missing dependencies + RpZ_data = _compute_RpZ_data( + parameterization, + data_index[parameterization][name]["dependencies"]["data"], + params=params, + transforms=transforms, + profiles=profiles, + data=data, + RpZ_data=RpZ_data, + **kwargs, + ) + # now compute the quantity + RpZ_data = data_index[parameterization][name]["fun"]( + params=params, + transforms=transforms, + profiles=profiles, + data=data, + RpZ_data=RpZ_data, + **kwargs, + ) + return RpZ_data + + @execute_on_cpu def get_data_deps(keys, obj, has_axis=False, basis="rpz", data=None): """Get list of keys needed to compute ``keys`` given already computed data. @@ -472,7 +597,7 @@ def get_profiles(keys, obj, grid=None, has_axis=False, basis="rpz"): @execute_on_cpu -def get_params(keys, obj, has_axis=False, basis="rpz"): +def get_params(keys, obj, has_axis=False, basis="rpz", params=None): """Get parameters needed to compute a given quantity. Parameters @@ -485,6 +610,8 @@ def get_params(keys, obj, has_axis=False, basis="rpz"): Whether the grid to compute on has a node on the magnetic axis. basis : {"rpz", "xyz"} Basis of computed quantities. + params : dict[str, jnp.ndarray] + Params computed so far. Returns ------- @@ -497,24 +624,34 @@ def get_params(keys, obj, has_axis=False, basis="rpz"): p = _parse_parameterization(obj) keys = [keys] if isinstance(keys, str) else keys deps = list(keys) + get_data_deps(keys, p, has_axis=has_axis, basis=basis) - params = [] + params_list = [] for key in deps: - params += data_index[p][key]["dependencies"]["params"] + params_list += data_index[p][key]["dependencies"]["params"] if isinstance(obj, str) or inspect.isclass(obj): - return params - temp_params = {} - for name in params: - p = getattr(obj, name) - if isinstance(p, dict): - temp_params[name] = p.copy() - else: - temp_params[name] = jnp.atleast_1d(p) - return temp_params + return params_list + + params = setdefault(params, {}) + for name in params_list: + if name not in params: + p = getattr(obj, name) + params[name] = ( + p.copy() + if isinstance(p, dict) + else (None if (p is None) else jnp.atleast_1d(p)) + ) + return params @execute_on_cpu -def get_transforms( - keys, obj, grid, jitable=False, has_axis=False, basis="rpz", **kwargs +def get_transforms( # noqa: C901 + keys, + obj, + grid, + jitable=False, + has_axis=False, + basis="rpz", + transforms=None, + **kwargs, ): """Get transforms needed to compute a given quantity on a given grid. @@ -532,10 +669,12 @@ def get_transforms( Whether the grid to compute on has a node on the magnetic axis. basis : {"rpz", "xyz"} Basis of computed quantities. + transforms : dict[str, Transform] + Transforms that are already computed. Returns ------- - transforms : dict of Transform + transforms : dict[str, Transform] Transforms needed to compute key. Keys for R, Z, L, etc @@ -548,8 +687,14 @@ def get_transforms( keys = [keys] if isinstance(keys, str) else keys has_axis = has_axis or (grid is not None and grid.axis.size) derivs = get_derivs(keys, obj, has_axis=has_axis, basis=basis) - transforms = {"grid": grid} + + transforms = setdefault(transforms, {}) + transforms.setdefault("grid", grid) + p = _parse_parameterization(obj) + for c in derivs.keys(): + if c in transforms: + continue if hasattr(obj, c + "_basis"): # regular stuff like R, Z, lambda etc. basis = getattr(obj, c + "_basis") # first check if we already have a transform with a compatible basis @@ -561,6 +706,12 @@ def get_transforms( ).astype(int) # don't build until we know all the derivs we need transform.change_derivatives(ders, build=False) + if ( + c == "Phi" + and p + == "desc.magnetic_fields._laplace.FreeSurfaceOuterField" + ): + transform.build_pinv() c_transform = transform break else: # if we didn't exit the loop early @@ -569,6 +720,12 @@ def get_transforms( basis, derivs=derivs[c], build=False, + build_pinv=c == "Phi" + and ( + p == "desc.magnetic_fields._laplace.SourceFreeField" + or p + == "desc.magnetic_fields._laplace.FreeSurfaceOuterField" + ), method=method, ) else: # don't perform checks if jitable=True as they are not jit-safe @@ -577,6 +734,11 @@ def get_transforms( basis, derivs=derivs[c], build=False, + build_pinv=c == "Phi" + and ( + p == "desc.magnetic_fields._laplace.SourceFreeField" + or p == "desc.magnetic_fields._laplace.FreeSurfaceOuterField" + ), method=method, ) transforms[c] = c_transform @@ -656,6 +818,13 @@ def has_data_dependencies(parameterization, qty, data, axis=False): ) +def has_RpZ_data_dependencies(parameterization, qty, data, RpZ_data): + """Determine if we have the data needed to compute qty.""" + p = _parse_parameterization(parameterization) + deps = data_index[p][qty]["dependencies"]["data"] + return all(d in data or d in RpZ_data for d in deps) + + def has_dependencies(parameterization, qty, params, transforms, profiles, data): """Determine if we have the ingredients needed to compute qty. diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index 8f274a6301..618ab7f356 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -136,7 +136,8 @@ class Equilibrium(IOAble, Optimizable): anisotropy : Profile or ndarray Anisotropic pressure profile or array of mode numbers and spectral coefficients. Default is a PowerSeriesProfile with zero anisotropic pressure. - surface: Surface or ndarray shape(k,5) (optional) + surface: Surface or ndarray + Shape(k,5) (optional). Fixed boundary surface shape, as a Surface object or array of spectral mode numbers and coefficients of the form [l, m, n, R, Z]. Default is a FourierRZToroidalSurface with major radius 10 and minor radius 1 @@ -868,7 +869,7 @@ def compute( # noqa: C901 Name(s) of the quantity(s) to compute. grid : Grid, optional Grid of coordinates to evaluate at. Defaults to the quadrature grid. - params : dict of ndarray + params : dict[str, jnp.ndarray] Parameters from the equilibrium, such as R_lmn, Z_lmn, i_l, p_l, etc Defaults to attributes of self. transforms : dict of Transform @@ -889,7 +890,7 @@ def compute( # noqa: C901 Returns ------- - data : dict of ndarray + data : dict[str, ndarray] Computed quantity and intermediate variables. """ diff --git a/desc/geometry/core.py b/desc/geometry/core.py index 7f5bf60013..f488afe483 100644 --- a/desc/geometry/core.py +++ b/desc/geometry/core.py @@ -111,7 +111,7 @@ def compute( Returns ------- - data : dict of ndarray + data : dict[str, jnp.ndarray] Computed quantity and intermediate variables. """ @@ -436,7 +436,7 @@ def N(self): @property def sym(self): - """bool: Whether or not the surface is stellarator symmetric.""" + """bool: Whether the surface is stellarator symmetric.""" return self._sym def _compute_orientation(self): @@ -495,12 +495,16 @@ def compute( grid : Grid, optional Grid of coordinates to evaluate at. Defaults to a Linear grid for constant rho surfaces or a Quadrature grid for constant zeta surfaces. - params : dict of ndarray - Parameters from the equilibrium. Defaults to attributes of self. + params : dict[str, jnp.ndarray] + Parameters from the equilibrium, such as R_lmn, Z_lmn, i_l, p_l, etc + Defaults to attributes of self. transforms : dict of Transform Transforms for R, Z, lambda, etc. Default is to build from grid - data : dict of ndarray - Data computed so far, generally output from other compute functions + data : dict[str, jnp.ndarray] + Data computed so far, generally output from other compute functions. + Any vector v = vΒΉ RΜ‚ + vΒ² Ο•Μ‚ + vΒ³ ZΜ‚ should be given in components + v = [vΒΉ, vΒ², vΒ³] where RΜ‚, Ο•Μ‚, ZΜ‚ are the normalized basis vectors + of the cylindrical coordinates R, Ο•, Z. override_grid : bool If True, override the user supplied grid if necessary and use a full resolution grid to compute quantities and then downsample to user requested @@ -509,10 +513,12 @@ def compute( Returns ------- - data : dict of ndarray + data : dict[str, jnp.ndarray] Computed quantity and intermediate variables. """ + RpZ_data = kwargs.pop("RpZ_data", None) + if isinstance(names, str): names = [names] if grid is None: @@ -532,8 +538,9 @@ def compute( f" instead got type {type(grid)}" ) - if params is None: - params = get_params(names, obj=self, basis=kwargs.get("basis", "rpz")) + params = get_params( + names, obj=self, basis=kwargs.get("basis", "rpz"), params=params + ) if transforms is None: transforms = get_transforms( names, @@ -609,6 +616,7 @@ def compute( transforms=transforms, profiles=profiles, data=data, + RpZ_data=RpZ_data, **kwargs, ) return data diff --git a/desc/grid.py b/desc/grid.py index 43d6963387..233e43db47 100644 --- a/desc/grid.py +++ b/desc/grid.py @@ -336,6 +336,16 @@ def unique_rho_idx(self): ) return self._unique_rho_idx + @property + def unique_theta(self): + """ndarray: Unique poloidal coordinates.""" + return self.compress(self.nodes[:, 1], "theta") + + @property + def unique_zeta(self): + """ndarray: Unique zeta coordinates.""" + return self.compress(self.nodes[:, 2], "zeta") + @property def unique_poloidal_idx(self): """ndarray: Indices of unique poloidal angle coordinates.""" diff --git a/desc/integrals/__init__.py b/desc/integrals/__init__.py index 6bb01b77f7..5c695da7ad 100644 --- a/desc/integrals/__init__.py +++ b/desc/integrals/__init__.py @@ -6,6 +6,7 @@ DFTInterpolator, FFTInterpolator, compute_B_plasma, + get_interpolator, singular_integral, virtual_casing_biot_savart, ) diff --git a/desc/integrals/quad_utils.py b/desc/integrals/quad_utils.py index c3424f8256..315b2f4d82 100644 --- a/desc/integrals/quad_utils.py +++ b/desc/integrals/quad_utils.py @@ -1,8 +1,9 @@ """Utilities for quadratures.""" from orthax.legendre import legder, legval +from scipy.special import roots_legendre -from desc.backend import eigh_tridiagonal, jnp, put +from desc.backend import eigh_tridiagonal, fori_loop, jnp, put from desc.utils import errorif @@ -317,3 +318,261 @@ def get_quadrature(quad, automorphism): w = w * grad_auto(x) x = auto(x) return x, w + + +def nfp_loop(source_grid, func, init_val): + """Calculate effects from source points on a single field period. + + The integral is computed on the full domain because the kernels of interest + have toroidal variation and are not NFP periodic. To that end, the integral + is computed on every field period and summed. The ``source_grid`` is the + first field period because DESC truncates the computational domain to + ΞΆ ∈ [0, 2Ο€/grid.NFP) and changes variables to the spectrally condensed + ΞΆ* = basis.NFP ΞΆ. The domain is shifted to the next field period by + incrementing the toroidal coordinate of the grid by 2Ο€/NFP. + + For an axisymmetric configuration, it is most efficient for ``source_grid`` to + be a single toroidal cross-section. To capture toroidal effects of the kernels + on those grids for axisymmetric configurations, we set a dummy value for NFP to + an integer larger than 1 so that the toroidal increment can move to a new spot. + + Parameters + ---------- + source_grid : _Grid + Grid with points ΞΆ ∈ [0, 2Ο€/grid.NFP). + func : callable + Should accept argument ``zeta_j`` denoting toroidal coordinates of + field period ``j``. + init_val : jnp.ndarray + Initial loop carry value. + + Returns + ------- + result : jnp.ndarray + Shape is ``init_val.shape``. + + """ + errorif( + source_grid.num_zeta == 1 and source_grid.NFP == 1, + msg="Source grid cannot compute toroidal effects.\n" + "Increase NFP of source grid to e.g. 64.", + ) + zeta = source_grid.nodes[:, 2] + NFP = source_grid.NFP + + def body(j, f): + return f + func(zeta + j * 2 * jnp.pi / NFP) + + return fori_loop(0, NFP, body, init_val) + + +def chi(r): + """Partition of unity function in polar coordinates. Eq 39 in [2]. + + Parameters + ---------- + r : jnp.ndarray + Absolute value of radial coordinate in polar domain. + + """ + return jnp.exp(-36 * jnp.abs(r) ** 8) + + +def eta(theta, zeta, theta0, zeta0, ht, hz, st, sz): + """Partition of unity function in rectangular coordinates. + + Consider the mapping from + (ΞΈ,ΞΆ) ∈ [-Ο€, Ο€) Γ— [-Ο€/NFP, Ο€/NFP) to (ρ,Ο‰) ∈ [βˆ’1, 1] Γ— [0, 2Ο€) + defined by + ΞΈ βˆ’ ΞΈβ‚€ = h₁ s₁/2 ρ sin Ο‰ + ΞΆ βˆ’ ΞΆβ‚€ = hβ‚‚ sβ‚‚/2 ρ cos Ο‰ + with Jacobian determinant norm h₁hβ‚‚ s₁sβ‚‚/4 |ρ|. + + In general in dimensions higher than one, the mapping that determines a + change of variable for integration must be bijective. This is satisfied + only if s₁ = 2Ο€/h₁ and sβ‚‚ = (2Ο€/NFP)/hβ‚‚. In the particular case the + integrand is nonzero in a subset of the domain, then the change of variable + need only be a bijective map where the function does not vanish, more + precisely, its set of compact support. + + The functions we integrate are proportional to Ξ·β‚€(ΞΈ,ΞΆ) = Ο‡β‚€(r) far from the + singularity at (ΞΈβ‚€,ΞΆβ‚€). Therefore, the support matches Ο‡β‚€(r)'s, assuming + this region is sufficiently large compared to the singular region. + Here Ο‡β‚€(r) has support where the argument r lies in [0, 1]. The map r + defines a coordinate mapping between the toroidal domain and a polar domain + such that the integration region in the polar domain (ρ,Ο‰) ∈ [βˆ’1, 1] Γ— [0, 2Ο€) + equals the compact support, and furthermore is a circular region around the + singular point in (ΞΈ,ΞΆ) geometry when s₁ Γ— sβ‚‚ denote the number of grid points + on a uniformly discretized toroidal domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β². + Ο‡β‚€ : r ↦ exp(βˆ’36r⁸) + + r : ρ, Ο‰ ↦ |ρ| + + r : ΞΈ, ΞΆ ↦ 2 [ (ΞΈβˆ’ΞΈβ‚€)Β²/(h₁s₁)Β² + (ΞΆβˆ’ΞΆβ‚€)Β²/(hβ‚‚sβ‚‚)Β² ]⁰ᐧ⁡ + + Hence, r β‰₯ 1 (r ≀ 1) outside (inside) the integration domain. + + The choice for the size of the support is determined by s₁ and sβ‚‚. + The optimal choice is dependent on the nature of the singularity e.g. if the + integrand decays quickly then the elliptical grid determined by s₁ and sβ‚‚ + can be made smaller and the integration will have higher resolution for the + same number of quadrature points. + + With the above definitions the support lies on an s₁ Γ— sβ‚‚ subset + of a field period which has ``num_theta`` Γ— ``num_zeta`` nodes total. + Since kernels are 2Ο€ periodic, the choice for sβ‚‚ should be multiplied by NFP. + Then the support lies on an s₁ Γ— sβ‚‚ subset of the full domain. For large NFP + devices such as Heliotron or tokamaks it is typical that s₁ β‰ͺ sβ‚‚. + + Parameters + ---------- + theta, zeta : jnp.ndarray + Coordinates of points to evaluate partition function Ξ·β‚€(ΞΈ,ΞΆ). + theta0, zeta0 : jnp.ndarray + Origin (ΞΈβ‚€,ΞΆβ‚€) where the partition Ξ·β‚€ is unity. + ht, hz : float + Grid step size in ΞΈ and ΞΆ. + st, sz : int + Extent of support is an ``st`` Γ— ``sz`` subset + of the full domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β² of ``source_grid``. + Subset of ``source_grid.num_theta`` Γ— ``source_grid.num_zeta*source_grid.NFP``. + + """ + dt = jnp.abs(theta - theta0) + dz = jnp.abs(zeta - zeta0) + # The distance spans (dΞΈ,dΞΆ) ∈ [0, Ο€]Β², independent of NFP. + dt = jnp.minimum(dt, 2 * jnp.pi - dt) + dz = jnp.minimum(dz, 2 * jnp.pi - dz) + r = 2 * jnp.hypot(dt / (ht * st), dz / (hz * sz)) + return chi(r) + + +def _get_polar_quadrature(q): + """Polar nodes for quadrature around singular point. + + Parameters + ---------- + q : int + Order of quadrature in radial and azimuthal directions. + + Returns + ------- + r, w : ndarray + Radial and azimuthal coordinates. + dr, dw : ndarray + Radial and azimuthal spacing and quadrature weights. + + """ + Nr = Nw = q + r, dr = roots_legendre(Nr) + # integrate separately over [-1,0] and [0,1] + r1 = 1 / 2 * r - 1 / 2 + r2 = 1 / 2 * r + 1 / 2 + r = jnp.concatenate([r1, r2]) + dr = jnp.concatenate([dr, dr]) / 2 + w = jnp.linspace(0, jnp.pi, Nw, endpoint=False) + dw = jnp.ones_like(w) * jnp.pi / Nw + r, w = jnp.meshgrid(r, w) + r = r.ravel() + w = w.ravel() + dr, dw = jnp.meshgrid(dr, dw) + dr = dr.ravel() + dw = dw.ravel() + return r, w, dr, dw + + +def _vanilla_params(grid): + """Parameters for support size and quadrature resolution. + + These parameters do not account for grid anisotropy. + + Parameters + ---------- + grid : LinearGrid + Grid that can fft2. + + Returns + ------- + st : int + Extent of support is an ``st`` Γ— ``sz`` subset + of the full domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β² of ``grid``. + Subset of ``grid.num_theta`` Γ— ``grid.num_zeta*grid.NFP``. + sz : int + Extent of support is an ``st`` Γ— ``sz`` subset + of the full domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β² of ``grid``. + Subset of ``grid.num_theta`` Γ— ``grid.num_zeta*grid.NFP``. + q : int + Order of quadrature in radial and azimuthal directions. + + """ + Nt = grid.num_theta + Nz = grid.num_zeta * grid.NFP + q = int(jnp.sqrt(grid.num_nodes) // 2) + s = min(q, Nt, Nz) + return s, s, q + + +def _best_params(grid, ratio): + """Parameters for heuristic support size and quadrature resolution. + + These parameters account for global grid anisotropy which ensures + more robust convergence across a wider aspect ratio range. + + Parameters + ---------- + grid : LinearGrid + Grid that can fft2. + ratio : float or jnp.ndarray + Best ratio. + + Returns + ------- + st : int + Extent of support is an ``st`` Γ— ``sz`` subset + of the full domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β² of ``grid``. + Subset of ``grid.num_theta`` Γ— ``grid.num_zeta*grid.NFP``. + sz : int + Extent of support is an ``st`` Γ— ``sz`` subset + of the full domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β² of ``grid``. + Subset of ``grid.num_theta`` Γ— ``grid.num_zeta*grid.NFP``. + q : int + Order of quadrature in radial and azimuthal directions. + + """ + assert grid.can_fft2 + Nt = grid.num_theta + Nz = grid.num_zeta * grid.NFP + q = int(jnp.sqrt(grid.num_nodes if (grid.num_zeta > 1) else (Nt * Nz)) // 2) + s = min(q, Nt, Nz) + # Size of singular region in real space = s * h * |e_.| + # For it to be a circle, choose radius ~ equal + # s_t * h_t * |e_t| = s_z * h_z * |e_z| + # s_z / s_t = h_t / h_z |e_t| / |e_z| = Nz*NFP/Nt |e_t| / |e_z| + # Denote ratio = < |e_z| / |e_t| > and + # s_ratio = s_z / s_t = Nz*NFP/Nt / ratio + # Also want sqrt(s_z*s_t) ~ s = q. + s_ratio = jnp.sqrt(Nz / Nt / ratio) + st = jnp.clip(jnp.ceil(s / s_ratio).astype(int), None, Nt) + sz = jnp.clip(jnp.ceil(s * s_ratio).astype(int), None, Nz) + if s_ratio.size == 1: + st = int(st) + sz = int(sz) + return st, sz, q + + +def _best_ratio(data): + """Ratio to make singular integration partition ~circle in real space. + + Parameters + ---------- + data : dict[str, jnp.ndarray] + Dictionary of data evaluated on single flux surface grid that ``can_fft2`` + with keys ``|e_theta x e_zeta|``, ``e_theta``, and ``e_zeta``. + + """ + scale = jnp.linalg.norm(data["e_zeta"], axis=-1) / jnp.linalg.norm( + data["e_theta"], axis=-1 + ) + return jnp.mean(scale * data["|e_theta x e_zeta|"]) / jnp.mean( + data["|e_theta x e_zeta|"] + ) diff --git a/desc/integrals/singularities.py b/desc/integrals/singularities.py index d180bc6289..0de7a3f334 100644 --- a/desc/integrals/singularities.py +++ b/desc/integrals/singularities.py @@ -1,291 +1,101 @@ """High order method for singular surface integrals, from Malhotra 2019.""" +import warnings from abc import ABC, abstractmethod +from functools import partial -import numpy as np -import scipy -from interpax import fft_interp2d -from interpax_fft import rfft2_modes, rfft2_vander +from interpax_fft import rfft2_modes, rfft2_vander, rfft_interp2d from scipy.constants import mu_0 -from desc.backend import fori_loop, jnp, rfft2 +from desc.backend import jit, jnp, rfft2 from desc.batching import batch_map, vmap_chunked -from desc.grid import LinearGrid +from desc.grid import LinearGrid # noqa: F401 +from desc.integrals.quad_utils import ( + _best_params, + _best_ratio, + _get_polar_quadrature, + chi, + eta, + nfp_loop, +) from desc.io import IOAble from desc.utils import ( + apply, check_posint, - errorif, + dot, parse_argname_change, rpz2xyz, rpz2xyz_vec, safediv, safenorm, + setdefault, warnif, xyz2rpz_vec, ) -def _chi(r): - """Partition of unity function in polar coordinates. Eq 39 in [2]. - - Parameters - ---------- - r : jnp.ndarray - Absolute value of radial coordinate in polar domain. - - """ - return jnp.exp(-36 * jnp.abs(r) ** 8) - - -def _eta(theta, zeta, theta0, zeta0, ht, hz, st, sz): - """Partition of unity function in rectangular coordinates. - - Consider the mapping from - (ΞΈ,ΞΆ) ∈ [-Ο€, Ο€) Γ— [-Ο€/NFP, Ο€/NFP) to (ρ,Ο‰) ∈ [βˆ’1, 1] Γ— [0, 2Ο€) - defined by - ΞΈ βˆ’ ΞΈβ‚€ = h₁ s₁/2 ρ sin Ο‰ - ΞΆ βˆ’ ΞΆβ‚€ = hβ‚‚ sβ‚‚/2 ρ cos Ο‰ - with Jacobian determinant norm h₁hβ‚‚ s₁sβ‚‚/4 |ρ|. - - In general in dimensions higher than one, the mapping that determines a - change of variable for integration must be bijective. This is satisfied - only if s₁ = 2Ο€/h₁ and sβ‚‚ = (2Ο€/NFP)/hβ‚‚. In the particular case the - integrand is nonzero in a subset of the domain, then the change of variable - need only be a bijective map where the function does not vanish, more - precisely, its set of compact support. - - The functions we integrate are proportional to Ξ·β‚€(ΞΈ,ΞΆ) = Ο‡β‚€(r) far from the - singularity at (ΞΈβ‚€,ΞΆβ‚€). Therefore, the support matches Ο‡β‚€(r)'s, assuming - this region is sufficiently large compared to the singular region. - Here Ο‡β‚€(r) has support where the argument r lies in [0, 1]. The map r - defines a coordinate mapping between the toroidal domain and a polar domain - such that the integration region in the polar domain (ρ,Ο‰) ∈ [βˆ’1, 1] Γ— [0, 2Ο€) - equals the compact support, and furthermore is a circular region around the - singular point in (ΞΈ,ΞΆ) geometry when s₁ Γ— sβ‚‚ denote the number of grid points - on a uniformly discretized toroidal domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β². - Ο‡β‚€ : r ↦ exp(βˆ’36r⁸) - - r : ρ, Ο‰ ↦ |ρ| - - r : ΞΈ, ΞΆ ↦ 2 [ (ΞΈβˆ’ΞΈβ‚€)Β²/(h₁s₁)Β² + (ΞΆβˆ’ΞΆβ‚€)Β²/(hβ‚‚sβ‚‚)Β² ]⁰ᐧ⁡ - - Hence, r β‰₯ 1 (r ≀ 1) outside (inside) the integration domain. - - The choice for the size of the support is determined by s₁ and sβ‚‚. - The optimal choice is dependent on the nature of the singularity e.g. if the - integrand decays quickly then the elliptical grid determined by s₁ and sβ‚‚ - can be made smaller and the integration will have higher resolution for the - same number of quadrature points. - - With the above definitions the support lies on an s₁ Γ— sβ‚‚ subset - of a field period which has ``num_theta`` Γ— ``num_zeta`` nodes total. - Since kernels are 2Ο€ periodic, the choice for sβ‚‚ should be multiplied by NFP. - Then the support lies on an s₁ Γ— sβ‚‚ subset of the full domain. For large NFP - devices such as Heliotron or tokamaks it is typical that s₁ β‰ͺ sβ‚‚. - - Parameters - ---------- - theta, zeta : jnp.ndarray - Coordinates of points to evaluate partition function Ξ·β‚€(ΞΈ,ΞΆ). - theta0, zeta0 : jnp.ndarray - Origin (ΞΈβ‚€,ΞΆβ‚€) where the partition Ξ·β‚€ is unity. - ht, hz : float - Grid step size in ΞΈ and ΞΆ. - st, sz : int - Extent of support is an ``st`` Γ— ``sz`` subset - of the full domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β² of ``source_grid``. - Subset of ``source_grid.num_theta`` Γ— ``source_grid.num_zeta*source_grid.NFP``. - - """ - dt = jnp.abs(theta - theta0) - dz = jnp.abs(zeta - zeta0) - # The distance spans (dΞΈ,dΞΆ) ∈ [0, Ο€]Β², independent of NFP. - dt = jnp.minimum(dt, 2 * jnp.pi - dt) - dz = jnp.minimum(dz, 2 * jnp.pi - dz) - r = 2 * jnp.hypot(dt / (ht * st), dz / (hz * sz)) - return _chi(r) - - -def _vanilla_params(grid): - """Parameters for support size and quadrature resolution. - - These parameters do not account for grid anisotropy. - - Parameters - ---------- - grid : LinearGrid - Grid that can fft2. - - Returns - ------- - st : int - Extent of support is an ``st`` Γ— ``sz`` subset - of the full domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β² of ``grid``. - Subset of ``grid.num_theta`` Γ— ``grid.num_zeta*grid.NFP``. - sz : int - Extent of support is an ``st`` Γ— ``sz`` subset - of the full domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β² of ``grid``. - Subset of ``grid.num_theta`` Γ— ``grid.num_zeta*grid.NFP``. - q : int - Order of quadrature in radial and azimuthal directions. - - """ - Nt = grid.num_theta - Nz = grid.num_zeta * grid.NFP - q = int(jnp.sqrt(grid.num_nodes) / 2) - s = min(q, Nt, Nz) - return s, s, q - - -def best_params(grid, ratio): - """Parameters for heuristic support size and quadrature resolution. - - These parameters account for global grid anisotropy which ensures - more robust convergence across a wider aspect ratio range. - - Parameters - ---------- - grid : LinearGrid - Grid that can fft2. - ratio : float - Mean best ratio. - - Returns - ------- - st : int - Extent of support is an ``st`` Γ— ``sz`` subset - of the full domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β² of ``grid``. - Subset of ``grid.num_theta`` Γ— ``grid.num_zeta*grid.NFP``. - sz : int - Extent of support is an ``st`` Γ— ``sz`` subset - of the full domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β² of ``grid``. - Subset of ``grid.num_theta`` Γ— ``grid.num_zeta*grid.NFP``. - q : int - Order of quadrature in radial and azimuthal directions. - - """ - assert grid.can_fft2 - Nt = grid.num_theta - Nz = grid.num_zeta * grid.NFP - if grid.num_zeta > 1: # actually has toroidal resolution - q = int(jnp.sqrt(grid.num_nodes) / 2) - else: # axisymmetry - q = int(jnp.sqrt(Nt * Nz) / 2) - s = min(q, Nt, Nz) - # Size of singular region in real space = s * h * |e_.| - # For it to be a circle, choose radius ~ equal - # s_t * h_t * |e_t| = s_z * h_z * |e_z| - # s_z / s_t = h_t / h_z |e_t| / |e_z| = Nz*NFP/Nt |e_t| / |e_z| - # Denote ratio = < |e_z| / |e_t| > and - # s_ratio = s_z / s_t = Nz*NFP/Nt / ratio - # Also want sqrt(s_z*s_t) ~ s = q. - s_ratio = jnp.sqrt(Nz / Nt / ratio) - st = min(Nt, int(jnp.ceil(s / s_ratio))) - sz = min(Nz, int(jnp.ceil(s * s_ratio))) - return st, sz, q - - -def _local_params(grid, ratio): - """Parameters for heuristic support size and quadrature resolution. - - These parameters account for local grid anisotropy to ensure - more robust convergence across stronger geometric shaping. - - Parameters - ---------- - grid : LinearGrid - Grid that can fft2. - ratio : tuple - Mean best ratio and local ratio - - Returns - ------- - st : int - Extent of support is an ``st`` Γ— ``sz`` subset - of the full domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β² of ``grid``. - Subset of ``grid.num_theta`` Γ— ``grid.num_zeta*grid.NFP``. - sz : int - Extent of support is an ``st`` Γ— ``sz`` subset - of the full domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β² of ``grid``. - Subset of ``grid.num_theta`` Γ— ``grid.num_zeta*grid.NFP``. - q : int - Order of quadrature in radial and azimuthal directions. - - """ - assert grid.can_fft2 - Nt = grid.num_theta - Nz = grid.num_zeta * grid.NFP - if grid.num_zeta > 1: # actually has toroidal resolution - q = int(jnp.sqrt(grid.num_nodes) / 2) - else: # axisymmetry - q = int(jnp.sqrt(Nt * Nz) / 2) - s = min(q, Nt, Nz) - ratio = (ratio[0] + ratio[1]) / 2 - # same logic as heuristic params - s_ratio = jnp.sqrt(Nz / Nt / ratio) - st = min(Nt, int(jnp.ceil(s / s_ratio))) - sz = min(Nz, int(jnp.ceil(s * s_ratio))) - return st, sz, q - - -def best_ratio(data, return_local=False): - """Ratio to make singular integration partition ~circle in real space. +def get_interpolator( + eval_grid, + source_grid, + source_data, + st=None, + sz=None, + q=None, + *, + use_dft=False, + warn_fft=True, + **kwargs, +): + """Get interpolator from Cartesian to polar domain. Parameters ---------- - data : dict[str, jnp.ndarray] - Dictionary of data evaluated on single flux surface grid that - ``can_fft2`` with keys ``|e_theta x e_zeta|``, ``e_theta``, and ``e_zeta``. - return_local : bool - Whether to return the local ratio as well as the mean global ratio. + eval_grid, source_grid : Grid + Evaluation and source points for the integral transform. + source_data : dict[str, jnp.ndarray] + Dictionary of data evaluated on single flux surface grid that ``can_fft2`` + with keys ``|e_theta x e_zeta|``, ``e_theta``, and ``e_zeta``. + use_dft : bool + Whether to use matrix multiplication transform from spectral to physical domain + instead of inverse fast Fourier transform. + warn_fft : bool + Whether to warn if the interpolation will be lossy. Default is ``True``. Returns ------- - mean : float - Mean best ratio. + f : _BIESTInterpolator + Interpolator that uses the specified method. """ - local = jnp.linalg.norm(data["e_zeta"], axis=-1) / jnp.linalg.norm( - data["e_theta"], axis=-1 - ) - mean = jnp.mean(local * data["|e_theta x e_zeta|"]) / jnp.mean( - data["|e_theta x e_zeta|"] + if st is None or sz is None or q is None: + _st, _sz, _q = _best_params(source_grid, _best_ratio(source_data)) + st = setdefault(st, _st) + sz = setdefault(sz, _sz) + q = setdefault(q, _q) + + if use_dft: + f = DFTInterpolator(eval_grid, source_grid, st, sz, q) + else: + try: + f = FFTInterpolator(eval_grid, source_grid, st, sz, q, warn_fft=warn_fft) + except AssertionError as e: + warnings.warn( + "Could not build FFT interpolator because:\n" + + str(e) + + "\nSwitching to DFT interpolator which is more expensive.", + ) + f = DFTInterpolator(eval_grid, source_grid, st, sz, q) + use_dft = True + + # TODO (#1599). + warnif( + use_dft, + RuntimeWarning, + msg="Computations may be performed incorrectly for large matrices " + "due to open issues with JAX. Until this is fixed, it is recommended to " + "validate results against computations with a small choice for chunk size.", ) - return (mean, local) if return_local else mean - - -def _get_quadrature_nodes(q): - """Polar nodes for quadrature around singular point. - - Parameters - ---------- - q : int - Order of quadrature in radial and azimuthal directions. - - Returns - ------- - r, w : ndarray - Radial and azimuthal coordinates. - dr, dw : ndarray - Radial and azimuthal spacing and quadrature weights. - - """ - Nr = Nw = q - r, dr = scipy.special.roots_legendre(Nr) - # integrate separately over [-1,0] and [0,1] - r1 = 1 / 2 * r - 1 / 2 - r2 = 1 / 2 * r + 1 / 2 - r = jnp.concatenate([r1, r2]) - dr = jnp.concatenate([dr, dr]) / 2 - w = jnp.linspace(0, jnp.pi, Nw, endpoint=False) - dw = jnp.ones_like(w) * jnp.pi / Nw - r, w = jnp.meshgrid(r, w) - r = r.flatten() - w = w.flatten() - dr, dw = jnp.meshgrid(dr, dw) - dr = dr.flatten() - dw = dw.flatten() - return r, w, dr, dw + return f class _BIESTInterpolator(IOAble, ABC): @@ -351,10 +161,20 @@ def __init__(self, eval_grid, source_grid, st, sz, q): self._q = q self._ht = 2 * jnp.pi / source_grid.num_theta self._hz = 2 * jnp.pi / source_grid.num_zeta / source_grid.NFP - r, w, _, _ = _get_quadrature_nodes(q) + r, w, _, _ = _get_polar_quadrature(q) self._shift_t = self._ht * st / 2 * r * jnp.sin(w) self._shift_z = self._hz * sz / 2 * r * jnp.cos(w) + @property + def eval_grid(self): + """Evaluation points.""" + return self._eval_grid + + @property + def source_grid(self): + """Source points for quadrature of kernels.""" + return self._source_grid + @property def st(self): """Extent of polar grid support. @@ -428,7 +248,7 @@ def __call__(self, f, i, *, vander=None): class FFTInterpolator(_BIESTInterpolator): - """FFT interpolation operator required for high order singular integration. + """FFT interpolation operator for high order polar quadrature. Parameters ---------- @@ -443,28 +263,32 @@ class FFTInterpolator(_BIESTInterpolator): Subset of ``source_grid.num_theta`` Γ— ``source_grid.num_zeta*source_grid.NFP``. q : int Order of quadrature in polar domain. + warn_fft : bool + Whether to warn if the interpolation will be lossy. Default is ``True``. """ - def __init__(self, eval_grid, source_grid, st, sz, q, **kwargs): + def __init__(self, eval_grid, source_grid, st, sz, q, *, warn_fft=True, **kwargs): st = parse_argname_change(st, kwargs, "s", "st") assert eval_grid.can_fft2, "Got False for eval_grid.can_fft2." warnif( - eval_grid.num_theta < source_grid.num_theta, - msg="Frequency spectrum of FFT interpolation will be truncated because " - "the evaluation grid has less resolution than the source grid.\n" + warn_fft and eval_grid.num_theta < (source_grid.num_theta // 2 + 1), + msg="Frequency spectrum of FFT interpolation will be truncated.\n" f"Got eval_grid.num_theta = {eval_grid.num_theta} < " - f"{source_grid.num_theta} = source_grid.num_theta.", + f"{source_grid.num_theta // 2 + 1} = source_grid.num_theta // 2 + 1.", ) warnif( - eval_grid.num_zeta < source_grid.num_zeta, - msg="Frequency spectrum of FFT interpolation will be truncated because " - "the evaluation grid has less resolution than the source grid.\n" + warn_fft and eval_grid.num_zeta < (source_grid.num_zeta // 2 + 1), + msg="Frequency spectrum of FFT interpolation will be truncated.\n" f"Got eval_grid.num_zeta = {eval_grid.num_zeta} < " - f"{source_grid.num_zeta} = source_grid.num_zeta.", + f"{source_grid.num_zeta // 2 + 1} = source_grid.num_zeta // 2 + 1.", ) super().__init__(eval_grid, source_grid, st, sz, q) + def fourier(self, f): + """Return Fourier transform of ``f`` as expected by this interpolator.""" + return self.source_grid.meshgrid_reshape(f, "rtz")[0] + def __call__(self, f, i, *, is_fourier=False, vander=None): """Interpolate ``f`` to polar node ``i`` around evaluation grid. @@ -494,23 +318,21 @@ def __call__(self, f, i, *, is_fourier=False, vander=None): Source data interpolated to ith polar node. """ - # Would need to add interpax code to DESC - # https://github.com/f0uriest/interpax/issues/53 - # for is_fourier to do anything. - shape = f.shape[1:] - return fft_interp2d( - self._source_grid.meshgrid_reshape(f, "rtz")[0], - n1=self._eval_grid.num_theta, - n2=self._eval_grid.num_zeta, + if not is_fourier: + f = self.fourier(f) + return rfft_interp2d( + f, + n1=self.eval_grid.num_theta, + n2=self.eval_grid.num_zeta, sx=self._shift_t[i], sy=self._shift_z[i], dx=self._ht, dy=self._hz, - ).reshape(self._eval_grid.num_nodes, *shape, order="F") + ).reshape(self.eval_grid.num_nodes, *f.shape[2:], order="F") class DFTInterpolator(_BIESTInterpolator): - """Fourier interpolation matrix required for high order singular integration. + """MMT interpolation operator for high order polar quadrature. Parameters ---------- @@ -523,7 +345,7 @@ class DFTInterpolator(_BIESTInterpolator): of the full domain (ΞΈ,ΞΆ) ∈ [0, 2Ο€)Β² of ``source_grid``. Subset of ``source_grid.num_theta`` Γ— ``source_grid.num_zeta*source_grid.NFP``. q : int - Order of quadrature in polar domain + Order of quadrature in polar domain. """ @@ -540,23 +362,23 @@ def __init__(self, eval_grid, source_grid, st, sz, q, **kwargs): def fourier(self, f): """Return Fourier transform of ``f`` as expected by this interpolator.""" - if (self._source_grid.num_zeta % 2) == 0: - i = (0, -1) - else: - i = 0 + i = (0, -1) if (self.source_grid.num_zeta % 2 == 0) else 0 return 2 * rfft2( - self._source_grid.meshgrid_reshape(f, "rtz")[0], + self.source_grid.meshgrid_reshape(f, "rtz")[0], axes=(0, 1), norm="forward", ).at[:, i].divide(2).reshape(-1, *f.shape[1:]) def vander_polar(self, i): """Return Vandermonde matrix for ith polar node.""" - theta = self._eval_grid.nodes[:, 1] + self._shift_t[i] - zeta = self._eval_grid.nodes[:, 2] + self._shift_z[i] - return rfft2_vander(theta, zeta, self._modes_fft, self._modes_rfft).reshape( - self._eval_grid.num_nodes, -1 - ) + return rfft2_vander( + self.eval_grid.unique_theta + self._shift_t[i], + self.eval_grid.unique_zeta + self._shift_z[i], + self._modes_fft, + self._modes_rfft, + inverse_idx_fft=self.eval_grid.inverse_theta_idx, + inverse_idx_rfft=self.eval_grid.inverse_zeta_idx, + ).reshape(self.eval_grid.num_nodes, -1) def __call__(self, f, i, *, is_fourier=False, vander=None): """Interpolate ``f`` to polar node ``i`` around evaluation grid. @@ -586,6 +408,32 @@ def __call__(self, f, i, *, is_fourier=False, vander=None): return jnp.real(vander @ f) +def _prune_data(eval_data, eval_grid, source_data, source_grid, kernel): + """Returns new dictionaries with only required data.""" + keys = ["R", "phi", "Z", "theta", "zeta"] + if hasattr(kernel, "eval_keys"): + keys = keys + kernel.eval_keys + + eval_data = apply(eval_data, jnp.asarray, keys) + if eval_grid is not None: + # Casting to JAX arrays reduces memory usage. + if "theta" not in eval_data: + eval_data["theta"] = jnp.asarray(eval_grid.nodes[:, 1]) + if "zeta" not in eval_data: + eval_data["zeta"] = jnp.asarray(eval_grid.nodes[:, 2]) + + # Can't prune Ο‰ because Ο‰ is need to interpolate Ο• in _singular_part. + keys = kernel.keys + ["omega", "theta", "zeta"] + source_data = apply(source_data, jnp.asarray, keys) + # to avoid adding keys to dictionary during iteration + if "theta" not in source_data: + source_data["theta"] = jnp.asarray(source_grid.nodes[:, 1]) + if "zeta" not in source_data: + source_data["zeta"] = jnp.asarray(source_grid.nodes[:, 2]) + + return eval_data, source_data + + def _nonsingular_part( eval_data, eval_grid, @@ -594,128 +442,104 @@ def _nonsingular_part( st, sz, kernel, + *, + ndim=None, chunk_size=None, ): """Integrate kernel over non-singular points. Generally follows sec 3.2.1 of [2]. + If ``eval_grid`` is ``None``, then takes Ξ· = 0. """ - source_theta = source_grid.nodes[:, 1] - # make sure source dict has zeta and phi to avoid - # adding keys to dict during iteration - source_zeta = source_data.setdefault("zeta", source_grid.nodes[:, 2]) - source_phi = source_data["phi"] - - eval_data = {key: eval_data[key] for key in kernel.keys if key in eval_data} - eval_data["theta"] = jnp.asarray(eval_grid.nodes[:, 1]) - eval_data["zeta"] = jnp.asarray(eval_grid.nodes[:, 2]) - + assert source_grid.can_fft2 ht = 2 * jnp.pi / source_grid.num_theta hz = 2 * jnp.pi / source_grid.num_zeta / source_grid.NFP - w = source_data["|e_theta x e_zeta|"][jnp.newaxis] * ht * hz - - def nfp_loop(j, f_data): - """Calculate effects from source points on a single field period. - - The surface integral is computed on the full domain because the kernels of - interest have toroidal variation and are not NFP periodic. To that end, the - integral is computed on every field period and summed. The ``source_grid`` is - the first field period because DESC truncates the computational domain to - ΞΆ ∈ [0, 2Ο€/grid.NFP) and changes variables to the spectrally condensed - ΞΆ* = basis.NFP ΞΆ. Therefore, we shift the domain to the next field period by - incrementing the toroidal coordinate of the grid by 2Ο€/NFP. For an axisymmetric - configuration, it is most efficient for ``source_grid`` to be a single toroidal - cross-section. To capture toroidal effects of the kernels on those grids for - axisymmetric configurations, we set a dummy value for NFP to an integer larger - than 1 so that the toroidal increment can move to a new spot. - """ - f, source_data = f_data - source_data["zeta"] = (source_zeta + j * 2 * jnp.pi / source_grid.NFP) % ( - 2 * jnp.pi - ) - source_data["phi"] = (source_phi + j * 2 * jnp.pi / source_grid.NFP) % ( - 2 * jnp.pi - ) - # nest this def to avoid having to pass the modified source_data around the loop - # easier to just close over it and let JAX figure it out + ndim = setdefault(ndim, kernel.ndim) + + source_zeta = source_data["zeta"] + source_phi = source_data["phi"] + + def func(zeta_j): + source_data["zeta"] = zeta_j + source_data["phi"] = zeta_j # TODO (#465) + + # nest this def and let JAX figure it out def eval_pt(eval_data_i): - k = kernel(eval_data_i, source_data).reshape( - -1, source_grid.num_nodes, kernel.ndim + _eta = ( + 0 + if eval_grid is None + else eta( + source_data["theta"], + source_data["zeta"], + eval_data_i["theta"][:, jnp.newaxis], + eval_data_i["zeta"][:, jnp.newaxis], + ht, + hz, + st, + sz, + ) ) - eta = _eta( - source_theta, - source_data["zeta"], - eval_data_i["theta"][:, jnp.newaxis], - eval_data_i["zeta"][:, jnp.newaxis], - ht, - hz, - st, - sz, + # absorbing (1 - eta) into ds to reduce number of flops by factor of ndim + return ( + kernel(eval_data_i, source_data, (ht * hz) * (1 - _eta)) + .reshape(-1, source_grid.num_nodes, ndim) + .sum(-2) ) - return jnp.sum(k * (w * (1 - eta))[..., jnp.newaxis], axis=1) - f += batch_map(eval_pt, eval_data, chunk_size).reshape( - eval_grid.num_nodes, kernel.ndim - ) - return f, source_data - - # This error should be raised earlier since this is not the only place - # we need the higher dummy NFP value, but the error message is more - # helpful with the nfp loop docstring. - errorif( - source_grid.num_zeta == 1 and source_grid.NFP == 1, - msg="Source grid cannot compute toroidal effects.\n" - "Increase NFP of source grid to e.g. 64.\n" - "This is required to " + nfp_loop.__doc__, - ) - f = jnp.zeros((eval_grid.num_nodes, kernel.ndim)) - f, _ = fori_loop(0, source_grid.NFP, nfp_loop, (f, source_data)) + return batch_map(eval_pt, eval_data, chunk_size).reshape(-1, ndim) - # undo rotation of source_zeta - source_data["zeta"] = source_zeta - source_data["phi"] = source_phi - # we sum vectors at different points, so they need to be in xyz for that to work - # but then need to convert vectors back to rpz + f = nfp_loop(source_grid, func, jnp.zeros((eval_data["phi"].size, ndim))) if kernel.ndim == 3: f = xyz2rpz_vec(f, phi=eval_data["phi"]) + # undo rotation of ΞΆ and Ο• + source_data["zeta"] = source_zeta + source_data["phi"] = source_phi + return f -def _singular_part(eval_data, source_data, kernel, interpolator, chunk_size=None): +def _singular_part( + eval_data, source_data, interpolator, kernel, *, known_map=None, chunk_size=None +): """Integrate singular point by interpolating to polar grid. Generally follows sec 3.2.2 of [2], with the following differences: - hyperparameter M replaced by ``st`` and ``sz``. - density sigma / function f is absorbed into kernel. + + TODO (#465): For nonzero Ο‰, the quadrature may not be symmetric about the + singular point. Hence the quadrature may not converge for Cauchy + principal values. Prove otherwise or remove singularity. + """ - eval_grid = interpolator._eval_grid - eval_theta = jnp.asarray(eval_grid.nodes[:, 1]) - eval_zeta = jnp.asarray(eval_grid.nodes[:, 2]) + eval_grid = interpolator.eval_grid + eval_theta = eval_grid.unique_theta + eval_zeta = eval_grid.unique_zeta - r, w, dr, dw = _get_quadrature_nodes(interpolator.q) + r, w, dr, dw = _get_polar_quadrature(interpolator.q) r = jnp.abs(r) # integrand of eq 38 in [2] except stuff that needs to be interpolated - v = ( - _chi(r) - * (interpolator.ht * interpolator.hz) - * (interpolator.st * interpolator.sz / 4) - * r - * dr - * dw - ) + v = interpolator.ht * interpolator.hz * interpolator.st * interpolator.sz / 4 + v = v * (chi(r) * r * dr * dw) - keys = set(["|e_theta x e_zeta|"] + kernel.keys) + keys = set(kernel.keys) if "phi" in keys: keys.remove("phi") # Ο• is not a periodic map of ΞΈ, ΞΆ. keys.add("omega") - keys = list(keys) - # Note that it is necessary to take the Fourier transforms of the + if known_map is not None: + map_name, map_fun = known_map + keys.remove(map_name) + # It is necessary to take the Fourier transforms of the # vector components of the orthonormal polar basis vectors RΜ‚, Ο•Μ‚, ZΜ‚. # Vector components of the Cartesian basis are not NFP periodic. - fsource = [interpolator.fourier(source_data[key]) for key in keys] + fsource = [ + (key, interpolator.fourier(source_data[key])) + for key in keys + if key in source_data + ] def polar_pt(i): """See sec 3.2.2 of [2]. @@ -727,53 +551,48 @@ def polar_pt(i): vander = interpolator.vander_polar(i) source_data_polar = { key: interpolator(val, i, is_fourier=True, vander=vander) - for key, val in zip(keys, fsource) + for key, val in fsource } - # Coordinates of the polar nodes around the evaluation point. - source_data_polar["theta"] = eval_theta + interpolator.shift_t[i] - source_data_polar["zeta"] = eval_zeta + interpolator.shift_z[i] + # coordinates of the polar nodes around the evaluation point + source_data_polar["theta"] = eval_data["theta"] + interpolator.shift_t[i] + source_data_polar["zeta"] = eval_data["zeta"] + interpolator.shift_z[i] if "omega" in keys: source_data_polar["phi"] = ( source_data_polar["zeta"] + source_data_polar["omega"] ) - # TODO (#465): For nonzero Ο‰, the quadrature may not be symmetric about the - # singular point for hypersingular kernels such as the Biot-Savart - # kernel. (Recall the singularity is in real space). Hence the quadrature - # may not converge to the desired Hadamard finite part. Prove otherwise or - # use uniform grid in ΞΈ, Ο• and map coordinates before starting the singular - # integral routine. - - # eval pts x source pts for 1 polar grid offset - k = kernel(eval_data, source_data_polar, diag=True).reshape( - eval_grid.num_nodes, kernel.ndim - ) - dS = v[i] * source_data_polar["|e_theta x e_zeta|"] - fi = k * dS[:, jnp.newaxis] - return fi + if known_map is not None: + source_data_polar[map_name] = map_fun( + eval_grid, + t=eval_theta + interpolator.shift_t[i], + z=eval_zeta + interpolator.shift_z[i], + ) + return kernel(eval_data, source_data_polar, v[i], diag=True) f = vmap_chunked( polar_pt, chunk_size=chunk_size, reduction=jnp.add, - # TODO (#1386): Infer jnp.add.reduce from reduction. - # https://github.com/jax-ml/jax/issues/23493. - chunk_reduction=lambda x: x.sum(axis=0), - )(jnp.arange(v.size)) - assert f.shape == (eval_grid.num_nodes, kernel.ndim) - - # we sum vectors at different points, so they need to be in xyz for that to work - # but then need to convert vectors back to rpz + chunk_reduction=_add_reduce, + )(jnp.arange(v.size)).reshape(eval_grid.num_nodes, -1) + if kernel.ndim == 3: f = xyz2rpz_vec(f, phi=eval_data["phi"]) return f +def _add_reduce(x): + return x.sum(0) + + def singular_integral( eval_data, source_data, - kernel, interpolator, + kernel, + *, + known_map=None, + ndim=None, chunk_size=None, **kwargs, ): @@ -789,23 +608,32 @@ def singular_integral( Parameters ---------- eval_data : dict - Dictionary of data at evaluation points (eval_grid passed to interpolator). - Keys should be those required by kernel as kernel.keys. Vector data should be - in rpz basis. + Dictionary of data at evaluation points (``interpolator.eval_grid``). + Should store (R, Ο•, Z) coordinates to evaluate field and any keys + in ``kernel.eval_keys``. + Vector data should be in rpz basis. source_data : dict - Dictionary of data at source points (source_grid passed to interpolator). Keys - should be those required by kernel as kernel.keys. Vector data should be in - rpz basis. + Dictionary of data at source points (``interpolatr.source_grid``). Keys + should be those required by kernel as ``kernel.keys``. + Vector data should be in rpz basis. + interpolator : _BIESTInterpolator + Function to interpolate from rectangular source grid to polar + source grid around each singular point. See ``FFTInterpolator`` or + ``DFTInterpolator`` kernel : str or callable Kernel function to evaluate. If str, one of the following: - '1_over_r' : 1 / |𝐫 βˆ’ 𝐫'| - 'nr_over_r3' : 𝐧'β‹…(𝐫 βˆ’ 𝐫') / |𝐫 βˆ’ 𝐫'|Β³ - 'biot_savart' : ΞΌβ‚€/4Ο€ 𝐊'Γ—(𝐫 βˆ’ 𝐫') / |𝐫 βˆ’ 𝐫'|Β³ - 'biot_savart_A' : ΞΌβ‚€/4Ο€ 𝐊' / |𝐫 βˆ’ 𝐫'| - If callable, should take 3 arguments: - eval_data : dict of data at evaluation points (primed) + '1_over_r' : 1 / |𝐫 βˆ’ 𝐫'| dS + 'nr_over_r3' : 𝐧'β‹…(𝐫 βˆ’ 𝐫') / |𝐫 βˆ’ 𝐫'|Β³ dS + 'biot_savart' : ΞΌβ‚€/4Ο€ 𝐊'Γ—(𝐫 βˆ’ 𝐫') / |𝐫 βˆ’ 𝐫'|Β³ dS + 'biot_savart_A' : ΞΌβ‚€/4Ο€ 𝐊' / |𝐫 βˆ’ 𝐫'| dS + If callable, should take 4 arguments: + eval_data : dict of data at evaluation points (primed) source_data : dict of data at source points (unprimed) - diag : boolean, whether to evaluate full cross interactions or just diagonal + ds : Surface area element (not weighted by β€–e_ΞΈ Γ— e_ΞΆβ€– Jacobian). + Broadcasts with shape + (eval_grid.num_nodes, source_grid.num_nodes). + diag : boolean, whether to evaluate full cross interactions + or just diagonal If a callable, should also have the attributes ``ndim`` and ``keys`` defined. ``ndim`` is an integer representing the dimensionality of the output function f, 1 if f is scalar, 3 if f is a vector, etc. @@ -814,10 +642,17 @@ def singular_integral( evaluation points. If vector valued, the input to the kernel function will be in rpz and output should be in xyz. - interpolator : _BIESTInterpolator - Function to interpolate from rectangular source grid to polar - source grid around each singular point. See ``FFTInterpolator`` or - ``DFTInterpolator`` + known_map : (str, callable) + Optional. If map used in kernel of singular integral is known, + then it is more efficient to provide a callable to compute it + rather than interpolating and evaluating a Fourier series. + First index should store the name of the map used in the kernel + e.g. "Phi (periodic)", and the second index should store the Python + callable that accepts a grid argument. + Should broadcast with shapes (..., source_grid.num_nodes, ndim). + ndim : int + Default is kernel.ndim. + In some applications it, may be useful to supply other values for batching. chunk_size : int or None Size to split computation into chunks. If no chunking should be done or the chunk size is the full input @@ -837,123 +672,258 @@ def singular_integral( """ chunk_size = parse_argname_change(chunk_size, kwargs, "loop", "chunk_size") - if chunk_size == 0: - chunk_size = None - # sanitize inputs, we need everything as jax arrays so they can be indexed - # properly in the loops - source_data = {key: jnp.asarray(val) for key, val in source_data.items()} - eval_data = {key: jnp.asarray(val) for key, val in eval_data.items()} + chunk_size = None if (chunk_size == 0) else chunk_size if isinstance(kernel, str): kernel = kernels[kernel] - out1 = _singular_part(eval_data, source_data, kernel, interpolator, chunk_size) + eval_grid = interpolator.eval_grid + source_grid = interpolator.source_grid + if kwargs.get("_prune_data", True): + eval_data, source_data = _prune_data( + eval_data, + eval_grid, + source_data, + source_grid, + kernel, + ) + out1 = _singular_part( + eval_data, + source_data, + interpolator, + kernel, + known_map=known_map, + chunk_size=chunk_size, + ) out2 = _nonsingular_part( eval_data, - interpolator._eval_grid, + eval_grid, source_data, - interpolator._source_grid, + source_grid, interpolator.st, interpolator.sz, kernel, - chunk_size, + ndim=ndim, + chunk_size=chunk_size, ) return out1 + out2 -def _kernel_nr_over_r3(eval_data, source_data, diag=False): - # n * r / |r|^3 - source_x = jnp.atleast_2d( - rpz2xyz(jnp.array([source_data["R"], source_data["phi"], source_data["Z"]]).T) +def _dx(eval_data, source_data, diag=False): + """Compute distance vector between eval and source points. + + Parameters + ---------- + eval_data : dict[str, jnp.ndarray] + x data evaluated on eval grid. + source_data : dict[str, jnp.ndarray] + y data evaluated on source grid. + diag : bool + Set to True to bypass outer product. + + Returns + ------- + dx : jnp.ndarray + The vector x-y where y is a source point and x is eval point, + in Cartesian coordinates. + Shape (num eval, num source, 3). + + """ + source_x = rpz2xyz( + jnp.column_stack([source_data["R"], source_data["phi"], source_data["Z"]]) ) - eval_x = jnp.atleast_2d( - rpz2xyz(jnp.array([eval_data["R"], eval_data["phi"], eval_data["Z"]]).T) + eval_x = rpz2xyz( + jnp.column_stack([eval_data["R"], eval_data["phi"], eval_data["Z"]]) ) - if diag: - dx = eval_x - source_x - else: - dx = eval_x[:, None] - source_x[None] - n = rpz2xyz_vec(source_data["e^rho"], phi=source_data["phi"]) - n = n / jnp.linalg.norm(n, axis=-1, keepdims=True) - r = safenorm(dx, axis=-1) - return safediv(jnp.sum(n * dx, axis=-1), r**3) + if not diag: + eval_x = eval_x[:, jnp.newaxis] + return eval_x - source_x -_kernel_nr_over_r3.ndim = 1 -_kernel_nr_over_r3.keys = ["R", "phi", "Z", "e^rho"] +_dx.keys = ["R", "phi", "Z"] -def _kernel_1_over_r(eval_data, source_data, diag=False): - # 1/|r| - source_x = jnp.atleast_2d( - rpz2xyz(jnp.array([source_data["R"], source_data["phi"], source_data["Z"]]).T) - ) - eval_x = jnp.atleast_2d( - rpz2xyz(jnp.array([eval_data["R"], eval_data["phi"], eval_data["Z"]]).T) +def _G(dx, keepdims=False): + """Fundamental solution to the Laplacian in ℝ³. + + Parameters + ---------- + dx : jnp.ndarray + The vector x-y where y is a source point and x is eval point, + in Cartesian coordinates. + Shape (num eval, num source, 3). + + Returns + ------- + G : jnp.ndarray + G(x-y) = -1/(4Ο€ β€–xβˆ’yβ€–). + Shape (num eval, num source). + + """ + return safediv(-1, 4 * jnp.pi * safenorm(dx, axis=-1, keepdims=keepdims)) + + +def _grad_G(dx): + """βˆ‡_x G(xβˆ’y) where G is the fundamental solution to the Laplacian in ℝ³. + + Parameters + ---------- + dx : jnp.ndarray + The vector x-y where y is a source point and x is eval point, + in Cartesian coordinates. + Shape (num eval, num source, 3). + + Returns + ------- + grad_G : jnp.ndarray + βˆ‡_x G(xβˆ’y) = (4Ο€)⁻¹ β€–xβˆ’y‖⁻³ (x-y). + Shape (num eval, num source). + + """ + return safediv(dx, 4 * jnp.pi * safenorm(dx, axis=-1, keepdims=True) ** 3) + + +def _kernel_1_over_r(eval_data, source_data, ds, diag=False): + """Returns -4Ο€ da(y) G(x-y) = β€–e_ΞΈ Γ— e_ΞΆβ€– dΞΈ dΞΆ β€–xβˆ’y‖⁻¹.""" + return ( + (-4 * jnp.pi * ds) + * source_data["|e_theta x e_zeta|"] + * _G(_dx(eval_data, source_data, diag)) ) - if diag: - dx = eval_x - source_x - else: - dx = eval_x[:, None] - source_x[None] - r = safenorm(dx, axis=-1) - return safediv(1, r) _kernel_1_over_r.ndim = 1 -_kernel_1_over_r.keys = ["R", "phi", "Z"] +_kernel_1_over_r.keys = _dx.keys + ["|e_theta x e_zeta|"] -def _kernel_biot_savart(eval_data, source_data, diag=False): - # K x r / |r|^3 - source_x = jnp.atleast_2d( - rpz2xyz(jnp.array([source_data["R"], source_data["phi"], source_data["Z"]]).T) - ) - eval_x = jnp.atleast_2d( - rpz2xyz(jnp.array([eval_data["R"], eval_data["phi"], eval_data["Z"]]).T) +def _kernel_nr_over_r3(eval_data, source_data, ds, diag=False): + """Returns 4Ο€ ds(y) β‹… βˆ‡_x G(xβˆ’y) = ds(y) β‹… β€–xβˆ’y‖⁻³ (x-y).""" + return (4 * jnp.pi * ds) * dot( + rpz2xyz_vec(source_data["e_theta x e_zeta"], phi=source_data["phi"]), + _grad_G(_dx(eval_data, source_data, diag)), ) - if diag: - dx = eval_x - source_x - else: - dx = eval_x[:, None] - source_x[None] + + +_kernel_nr_over_r3.ndim = 1 +_kernel_nr_over_r3.keys = _dx.keys + ["e_theta x e_zeta"] + + +def _kernel_biot_savart(eval_data, source_data, ds, diag=False): + """Returns (ΞΌβ‚€ K(y) x βˆ‡_x G(xβˆ’y)) da(y) = (ΞΌβ‚€/4Ο€) K(y) da(y) Γ— (x-y) β€–xβˆ’y‖⁻³.""" + if jnp.ndim(ds) > 0: + ds = ds[..., jnp.newaxis] K = rpz2xyz_vec(source_data["K_vc"], phi=source_data["phi"]) - num = jnp.cross(K, dx, axis=-1) - r = safenorm(dx, axis=-1)[..., None] - return mu_0 / 4 / jnp.pi * safediv(num, r**3) + return ds * jnp.cross( + mu_0 * K * source_data["|e_theta x e_zeta|"][:, jnp.newaxis], + _grad_G(_dx(eval_data, source_data, diag)), + ) _kernel_biot_savart.ndim = 3 -_kernel_biot_savart.keys = ["R", "phi", "Z", "K_vc"] +_kernel_biot_savart.keys = _dx.keys + ["K_vc", "|e_theta x e_zeta|"] -def _kernel_biot_savart_A(eval_data, source_data, diag=False): - # K / |r| - source_x = jnp.atleast_2d( - rpz2xyz(jnp.array([source_data["R"], source_data["phi"], source_data["Z"]]).T) - ) - eval_x = jnp.atleast_2d( - rpz2xyz(jnp.array([eval_data["R"], eval_data["phi"], eval_data["Z"]]).T) - ) - if diag: - dx = eval_x - source_x - else: - dx = eval_x[:, None] - source_x[None] - r = safenorm(dx, axis=-1)[..., None] +def _kernel_biot_savart_A(eval_data, source_data, ds, diag=False): + """Returns ds(y) (-ΞΌβ‚€K)(y) G(xβˆ’y) = (ΞΌβ‚€/4Ο€) ds(y) K(y) β€–xβˆ’y‖⁻¹.""" + if jnp.ndim(ds) > 0: + ds = ds[..., jnp.newaxis] K = rpz2xyz_vec(source_data["K_vc"], phi=source_data["phi"]) - return mu_0 / 4 / jnp.pi * safediv(K, r) + return ( + ds + * source_data["|e_theta x e_zeta|"][:, jnp.newaxis] + * (-mu_0 * K) + * _G(_dx(eval_data, source_data, diag), keepdims=True) + ) _kernel_biot_savart_A.ndim = 3 -_kernel_biot_savart_A.keys = ["R", "phi", "Z", "K_vc"] +_kernel_biot_savart_A.keys = _dx.keys + ["K_vc", "|e_theta x e_zeta|"] + + +def _kernel_BS_plus_grad_S(eval_data, source_data, ds, diag=False): + """Returns K(y) (Tesla) x βˆ‡_x G(xβˆ’y) da(y) + βˆ‡_x G(xβˆ’y) Bβ‚™(y) da(y).""" + if jnp.ndim(ds) > 0: + ds = ds[..., jnp.newaxis] + K = rpz2xyz_vec(source_data["K_vc (periodic)"], phi=source_data["phi"]) + a = source_data["|e_theta x e_zeta|"] + grad_G = _grad_G(_dx(eval_data, source_data, diag)) + return ds * ( + jnp.cross(K * a[:, jnp.newaxis], grad_G) + + grad_G * (source_data["B0*n"] * a)[:, jnp.newaxis] + ) + + +_kernel_BS_plus_grad_S.ndim = 3 +_kernel_BS_plus_grad_S.keys = _dx.keys + [ + "K_vc (periodic)", + "B0*n", + "|e_theta x e_zeta|", +] + + +def _kernel_monopole(eval_data, source_data, ds, diag=False): + """Kernel of single layer operator S[B0*n]: (B0*n)(y) G(x-y) da(y).""" + return ( + ds + * (source_data["|e_theta x e_zeta|"] * source_data["B0*n"]) + * _G(_dx(eval_data, source_data, diag)) + ) + + +_kernel_monopole.ndim = 1 +_kernel_monopole.keys = _dx.keys + ["B0*n", "|e_theta x e_zeta|"] + + +def _kernel_dipole(eval_data, source_data, ds, diag=False): + """Kernel of double layer operator D[Ξ¦]: Ξ¦(y)γ€ˆβˆ‡_x G(xβˆ’y),n(y)〉da(y).""" + out = ds * dot( + rpz2xyz_vec(source_data["e_theta x e_zeta"], phi=source_data["phi"]), + _grad_G(_dx(eval_data, source_data, diag)), + ) + if source_data["Phi (periodic)"].ndim > 1: + out = out[..., jnp.newaxis] + # Do operation with Ξ¦ at the end, so that the following + # outer product plus reduction is more likely to be fused. + return source_data["Phi (periodic)"] * out + + +_kernel_dipole.ndim = 1 +_kernel_dipole.keys = _dx.keys + ["e_theta x e_zeta", "Phi (periodic)"] +def _kernel_dipole_plus_half(eval_data, source_data, ds, diag=False): + """Kernel of operator (D[Ξ¦] + Ξ¦/2)(x).""" + eval_Phi = eval_data["Phi(x) (periodic)"] + if not diag: + eval_Phi = eval_Phi[:, jnp.newaxis] + out = ds * dot( + rpz2xyz_vec(source_data["e_theta x e_zeta"], phi=source_data["phi"]), + _grad_G(_dx(eval_data, source_data, diag)), + ) + if source_data["Phi (periodic)"].ndim > 1: + out = out[..., jnp.newaxis] + # Do operation with Ξ¦ at the end, so that the following + # outer product plus reduction is more likely to be fused. + return (source_data["Phi (periodic)"] - eval_Phi) * out + + +_kernel_dipole_plus_half.ndim = 1 +_kernel_dipole_plus_half.keys = _dx.keys + ["e_theta x e_zeta", "Phi (periodic)"] +_kernel_dipole_plus_half.eval_keys = ["Phi(x) (periodic)"] + kernels = { "1_over_r": _kernel_1_over_r, "nr_over_r3": _kernel_nr_over_r3, "biot_savart": _kernel_biot_savart, "biot_savart_A": _kernel_biot_savart_A, + "biot_savart_grad_S": _kernel_BS_plus_grad_S, + "monopole": _kernel_monopole, + "dipole": _kernel_dipole, + "dipole_plus_half": _kernel_dipole_plus_half, } +@partial(jit, static_argnames=["chunk_size", "loop"]) def virtual_casing_biot_savart( eval_data, source_data, interpolator, chunk_size=None, **kwargs ): @@ -970,27 +940,33 @@ def virtual_casing_biot_savart( This 3D integral can be converted to a 2D integral over the plasma boundary using the virtual casing principle [1]_ - 𝐁α΅₯(𝐫) = ΞΌβ‚€/4Ο€ ∫ (𝐧' β‹… 𝐁(𝐫')) (𝐫 βˆ’ 𝐫')/|𝐫 βˆ’ 𝐫'|Β³ d²𝐫' - + ΞΌβ‚€/4Ο€ ∫ (𝐧' Γ— 𝐁(𝐫') Γ— (𝐫 βˆ’ 𝐫')/ |𝐫 βˆ’ 𝐫'|Β³ d²𝐫' - + 𝐁(𝐫)/2 + 𝐁α΅₯(𝐫) = ΞΌβ‚€/4Ο€ ∫ (𝐧' β‹… 𝐁(𝐫')) * (𝐫 βˆ’ 𝐫')/|𝐫 βˆ’ 𝐫'|Β³ d²𝐫' + + ΞΌβ‚€/4Ο€ ∫ (𝐧' Γ— 𝐁(𝐫')) Γ— (𝐫 βˆ’ 𝐫')/|𝐫 βˆ’ 𝐫'|Β³ d²𝐫' + + 𝐁(𝐫)/2 Where 𝐁 is the total field on the surface and 𝐧' is the outward surface normal. Because the total field is tangent, the first term in the integrand is zero leaving - 𝐁α΅₯(𝐫) = ΞΌβ‚€/4Ο€ ∫ K_vc(𝐫') Γ— (𝐫 βˆ’ 𝐫')/ |𝐫 βˆ’ 𝐫'|Β³ d²𝐫' + 𝐁(𝐫)/2 + 𝐁α΅₯(𝐫) = ΞΌβ‚€/4Ο€ ∫ K_vc(𝐫') Γ— (𝐫 βˆ’ 𝐫')/|𝐫 βˆ’ 𝐫'|Β³ d²𝐫' + 𝐁(𝐫)/2 Where we have defined the virtual casing sheet current K_vc = 𝐧' Γ— 𝐁(𝐫') + References + ---------- + [1] Hanson, James D. "The virtual-casing principle and Helmholtz’s theorem." + Plasma Physics and Controlled Fusion 57.11 (2015): 115006. + Parameters ---------- eval_data : dict - Dictionary of data at evaluation points (eval_grid passed to interpolator). - Keys should be those required by kernel as kernel.keys. Vector data should be - in rpz basis. + Dictionary of data at evaluation points (``interpolator.eval_grid``). + Should store (R, Ο•, Z) coordinates to evaluate field and any keys + in ``kernel.eval_keys``. + Vector data should be in rpz basis. source_data : dict - Dictionary of data at source points (source_grid passed to interpolator). Keys - should be those required by kernel as kernel.keys. Vector data should be in - rpz basis. + Dictionary of data at source points (``interpolatr.source_grid``). Keys + should be those required by kernel as ``kernel.keys``. + Vector data should be in rpz basis. interpolator : _BIESTInterpolator Function to interpolate from rectangular source grid to polar source grid around each singular point. See ``FFTInterpolator`` or @@ -1005,18 +981,13 @@ def virtual_casing_biot_savart( f : ndarray, shape(eval_grid.num_nodes, kernel.ndim) Integral transform evaluated at eval_grid. Vectors are in rpz basis. - References - ---------- - .. [1] Hanson, James D. "The virtual-casing principle and Helmholtz’s theorem." - Plasma Physics and Controlled Fusion 57.11 (2015): 115006. - """ return singular_integral( eval_data, source_data, - _kernel_biot_savart, interpolator, - chunk_size, + _kernel_biot_savart, + chunk_size=chunk_size, **kwargs, ) @@ -1037,17 +1008,22 @@ def compute_B_plasma( This 3D integral can be converted to a 2D integral over the plasma boundary using the virtual casing principle [1]_ - 𝐁α΅₯(𝐫) = ΞΌβ‚€/4Ο€ ∫ (𝐧' β‹… 𝐁(𝐫')) (𝐫 βˆ’ 𝐫')/|𝐫 βˆ’ 𝐫'|Β³ d²𝐫' - + ΞΌβ‚€/4Ο€ ∫ (𝐧' Γ— 𝐁(𝐫') Γ— (𝐫 βˆ’ 𝐫')/ |𝐫 βˆ’ 𝐫'|Β³ d²𝐫' - + 𝐁(𝐫)/2 + 𝐁α΅₯(𝐫) = ΞΌβ‚€/4Ο€ ∫ (𝐧' β‹… 𝐁(𝐫')) * (𝐫 βˆ’ 𝐫')/|𝐫 βˆ’ 𝐫'|Β³ d²𝐫' + + ΞΌβ‚€/4Ο€ ∫ (𝐧' Γ— 𝐁(𝐫')) Γ— (𝐫 βˆ’ 𝐫')/|𝐫 βˆ’ 𝐫'|Β³ d²𝐫' + + 𝐁(𝐫)/2 Where 𝐁 is the total field on the surface and 𝐧' is the outward surface normal. Because the total field is tangent, the first term in the integrand is zero leaving - 𝐁α΅₯(𝐫) = ΞΌβ‚€/4Ο€ ∫ K_vc(𝐫') Γ— (𝐫 βˆ’ 𝐫')/ |𝐫 βˆ’ 𝐫'|Β³ d²𝐫' + 𝐁(𝐫)/2 + 𝐁α΅₯(𝐫) = ΞΌβ‚€/4Ο€ ∫ K_vc(𝐫') Γ— (𝐫 βˆ’ 𝐫')/|𝐫 βˆ’ 𝐫'|Β³ d²𝐫' + 𝐁(𝐫)/2 Where we have defined the virtual casing sheet current K_vc = 𝐧' Γ— 𝐁(𝐫') + References + ---------- + [1] Hanson, James D. "The virtual-casing principle and Helmholtz’s theorem." + Plasma Physics and Controlled Fusion 57.11 (2015): 115006. + Parameters ---------- eq : Equilibrium @@ -1069,41 +1045,29 @@ def compute_B_plasma( Magnetic field evaluated at eval_grid. If normal_only=False, vector B is in rpz basis. - References - ---------- - .. [1] Hanson, James D. "The virtual-casing principle and Helmholtz’s theorem." - Plasma Physics and Controlled Fusion 57.11 (2015): 115006. - """ if source_grid is None: source_grid = LinearGrid( - rho=np.array([1.0]), M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP if eq.N > 0 else 64, sym=False, ) - data_keys = ["K_vc", "B", "R", "phi", "Z", "e^rho", "n_rho", "|e_theta x e_zeta|"] - eval_data = eq.compute(data_keys, grid=eval_grid) - source_data = eq.compute(data_keys, grid=source_grid) - st, sz, q = best_params(source_grid, best_ratio(source_data)) - try: - interpolator = FFTInterpolator(eval_grid, source_grid, st, sz, q) - except AssertionError as e: - warnif( - True, - msg="Could not build fft interpolator, switching to dft which is slow." - "\nReason: " + str(e), - ) - interpolator = DFTInterpolator(eval_grid, source_grid, st, sz, q) + eval_data = eq.compute(_dx.keys + ["B", "n_rho"], grid=eval_grid) + source_data = eq.compute( + _kernel_biot_savart.keys + ["|e_theta x e_zeta|"], grid=source_grid + ) if hasattr(eq.surface, "Phi_mn"): - source_data["K_vc"] += eq.surface.compute("K", grid=source_grid)["K"] + source_data = eq.surface.compute("K", grid=source_grid, data=source_data) + source_data["K_vc"] += source_data["K"] + + interpolator = get_interpolator(eval_grid, source_grid, source_data) Bplasma = virtual_casing_biot_savart( - eval_data, source_data, interpolator, chunk_size + eval_data, source_data, interpolator, chunk_size=chunk_size ) # need extra factor of B/2 bc we're evaluating on plasma surface - Bplasma = Bplasma + eval_data["B"] / 2 + Bplasma += eval_data["B"] / 2 if normal_only: - Bplasma = jnp.sum(Bplasma * eval_data["n_rho"], axis=1) + Bplasma = dot(Bplasma, eval_data["n_rho"]) return Bplasma diff --git a/desc/magnetic_fields/__init__.py b/desc/magnetic_fields/__init__.py index 49115a2a3d..8f2498926c 100644 --- a/desc/magnetic_fields/__init__.py +++ b/desc/magnetic_fields/__init__.py @@ -1,5 +1,7 @@ """Classes for Magnetic Fields.""" +from desc.compute._laplace import Options + from ._core import ( MagneticFieldFromUser, OmnigenousField, @@ -21,3 +23,4 @@ solve_regularized_surface_current, ) from ._dommaschk import DommaschkPotentialField, dommaschk_potential +from ._laplace import FreeSurfaceOuterField, SourceFreeField diff --git a/desc/magnetic_fields/_core.py b/desc/magnetic_fields/_core.py index c063c50b71..9d3c5d9f7b 100644 --- a/desc/magnetic_fields/_core.py +++ b/desc/magnetic_fields/_core.py @@ -357,7 +357,7 @@ def compute_Bnormal( if None defaults to a LinearGrid with twice the surface poloidal and toroidal resolutions points are in surface angular coordinates i.e theta and zeta - source_grid : Grid, int or None + source_grid : Grid or int or None Grid used to discretize MagneticField object if calculating B from Biot-Savart. Should NOT include endpoint at 2pi. vc_source_grid : LinearGrid @@ -3064,7 +3064,7 @@ def compute( Returns ------- - data : dict of ndarray + data : dict[str, jnp.ndarray] Computed quantity and intermediate variables. """ diff --git a/desc/magnetic_fields/_current_potential.py b/desc/magnetic_fields/_current_potential.py index fd4bfc85ac..bde5ed846e 100644 --- a/desc/magnetic_fields/_current_potential.py +++ b/desc/magnetic_fields/_current_potential.py @@ -8,7 +8,7 @@ import skimage.measure from scipy.constants import mu_0 -from desc.backend import cho_factor, cho_solve, fori_loop, jnp +from desc.backend import cho_factor, cho_solve, jnp from desc.basis import DoubleFourierSeries from desc.compute.utils import _compute as compute_fun from desc.derivatives import Derivative @@ -31,6 +31,7 @@ xyz2rpz_vec, ) +from ..integrals.quad_utils import nfp_loop from ._core import ( _MagneticField, biot_savart_general, @@ -45,7 +46,8 @@ class CurrentPotentialField(_MagneticField, FourierRZToroidalSurface): where: - n is the winding surface unit normal. - - Phi is the current potential function, which is a function of theta and zeta. + - Ξ¦ is the current potential function, which is a function of theta and zeta. + - βˆ‡Ξ¦ dot n is assumed to be zero. This function then uses biot-savart to find the B field from this current density K on the surface. @@ -413,9 +415,10 @@ class FourierCurrentPotentialField(_MagneticField, FourierRZToroidalSurface): where: - n is the winding surface unit normal. - - Phi is the current potential function, which is a function of theta and zeta, + - Ξ¦ is the current potential function, which is a function of theta and zeta, and is given as a secular linear term in theta/zeta and a double Fourier series in theta/zeta. + - βˆ‡Ξ¦ dot n is assumed to be zero. This class then uses biot-savart to find the B field from this current density K on the surface. @@ -862,10 +865,13 @@ def to_CoilSet( # noqa: C901 - FIXME: simplify this Ξ¦(ΞΈ,ΞΆ) = Ξ¦β‚›α΅₯(ΞΈ,ΞΆ) + GΞΆ/2Ο€ + IΞΈ/2Ο€ - where n is the winding surface unit normal, Ξ¦ is the current potential - function, which is a function of theta and zeta, and is given as a - secular linear term in theta (I) and zeta (G) and a double Fourier - series in theta/zeta. + where: + + - n is the winding surface unit normal. + - Ξ¦ is the current potential function, which is a function of theta and zeta, + and is given as a secular linear term in theta (I) and zeta (G) and a double + Fourier series in theta/zeta. + - βˆ‡Ξ¦ dot n is assumed to be zero. NOTE: The function is not jit/AD compatible @@ -1129,30 +1135,23 @@ def _compute_A_or_B_from_CurrentPotentialField( profiles={}, ) - _rs = data["x"] - _K = data["K"] - + R = data["x"][:, 0] + Z = data["x"][:, 2] # surface element, must divide by NFP to remove the NFP multiple on the # surface grid weights, as we account for that when doing the for loop # over NFP - _dV = source_grid.weights * data["|e_theta x e_zeta|"] / source_grid.NFP + dV = source_grid.weights * data["|e_theta x e_zeta|"] / source_grid.NFP - def nfp_loop(j, f): - # calculate (by rotating) rs, rs_t, rz_t - phi = (source_grid.nodes[:, 2] + j * 2 * jnp.pi / source_grid.NFP) % ( - 2 * jnp.pi - ) - # new coords are just old R,Z at a new phi (bc of discrete NFP symmetry) - rs = jnp.vstack((_rs[:, 0], phi, _rs[:, 2])).T - rs = rpz2xyz(rs) - K = rpz2xyz_vec(_K, phi=phi) - fj = op(coords, rs, K, _dV, chunk_size=chunk_size) - f += fj - return f - - B = fori_loop(0, source_grid.NFP, nfp_loop, jnp.zeros_like(coords)) + def func(zeta_j): + rs = rpz2xyz(jnp.column_stack([R, zeta_j, Z])) + K = rpz2xyz_vec(data["K"], phi=zeta_j) + return op(coords, rs, K, dV, chunk_size=chunk_size) + + B = nfp_loop(source_grid, func, jnp.zeros_like(coords)) if basis == "rpz": B = xyz2rpz_vec(B, x=coords[:, 0], y=coords[:, 1]) + else: + assert basis == "xyz" return B @@ -1711,6 +1710,7 @@ def _find_current_potential_contours( - Ξ¦ is the current potential function, which is a function of theta and zeta, and is given as a secular linear term in theta (I) and zeta (G) and a double Fourier series in theta/zeta. + - βˆ‡Ξ¦ dot n is assumed to be zero. Parameters ---------- diff --git a/desc/magnetic_fields/_laplace.py b/desc/magnetic_fields/_laplace.py new file mode 100644 index 0000000000..288187d48e --- /dev/null +++ b/desc/magnetic_fields/_laplace.py @@ -0,0 +1,401 @@ +"""High order accurate multiply connected geometry Laplace solver as described in [1]_. + +References +---------- +.. [1] Unalmis et al. New high-order accurate free surface stellarator + equilibria optimization and boundary integral methods in DESC. + +""" + +from desc.basis import DoubleFourierSeries +from desc.geometry import FourierRZToroidalSurface +from desc.integrals.singularities import get_interpolator +from desc.magnetic_fields import ToroidalMagneticField +from desc.utils import errorif, setdefault, warnif + + +class SourceFreeField(FourierRZToroidalSurface): + """Compute source free magnetic fields. + + Implements the Neumann formulation in multiply connected + geometry described in [1]_. + + Let 𝒳 be an open set with continuously differentiable + closed boundary βˆ‚π’³. This class solves the following + partial differential equation for + varphi = Ο† = Ξ¦ (periodic) = ``Phi (periodic)``. + + - βˆ†Ο†(x) = 0 x ∈ 𝒳 + - (B - βˆ‡Ο† - Bβ‚€)(x) = 0 x ∈ 𝒳 + - n dot (βˆ‡Ο† + Bβ‚€)(x) = 0 x ∈ βˆ‚π’³ + - n dot B(x) = 0 x ∈ βˆ‚π’³ + - curl (B - Bβ‚€)(x) = 0 x βˆ‰ βˆ‚π’³ + - div B(x) = 0 βˆ€x + + Parameters + ---------- + surface : Surface + Geometry defining βˆ‚π’³. + M : int + Poloidal Fourier resolution to interpolate potential on βˆ‚π’³. + N : int + Toroidal Fourier resolution to interpolate potential on βˆ‚π’³. + NFP : int + Field periodicity of potential on βˆ‚π’³. + Default is ``surface.NFP`` which is correct only if + the globally defined part of ``B0`` produces an NFP periodic + field. + sym : str + Symmetry for Fourier basis interpolating the periodic part of the + potential. Default is ``False``. + B0 : _MagneticField + Magnetic field due to currents in 𝒳 and net currents outside 𝒳 + I : float + Net toroidal current determining a circulation of Ξ¦ (not Ο†). + Default is zero. + Y : float + Net poloidal current determining a circulation of Ξ¦ (not Ο†). + Default is zero. + + """ + + _immediate_attributes_ = ["_surface", "_Phi_basis", "_B0", "I", "Y"] + + def __init__( + self, + surface, + M, + N, + NFP=None, + sym=False, + B0=None, + I=0.0, # noqa: E741 + Y=0.0, + ): + self._surface = surface + self._Phi_basis = DoubleFourierSeries( + M=M, N=N, NFP=setdefault(NFP, surface.NFP), sym=sym + ) + self.I = I + self.Y = Y + self._B0 = B0 + + def __getattr__(self, attr): + return getattr(self._surface, attr) + + def __setattr__(self, name, value): + if name in SourceFreeField._immediate_attributes_: + object.__setattr__(self, name, value) + else: + setattr(object.__getattribute__(self, "_surface"), name, value) + + def __hasattr__(self, attr): + return hasattr(self, attr) or hasattr(self._surface, attr) + + @property + def surface(self): + """Surface geometry defining boundary.""" + return self._surface + + @property + def Phi_basis(self): + """DoubleFourierSeries: Basis for periodic part of potential.""" + return self._Phi_basis + + @property + def sym_Phi(self): + """str: Type of symmetry of periodic part of Phi (no symmetry if False).""" + return self._Phi_basis.sym + + @property + def M_Phi(self): + """int: Poloidal resolution of periodic part of Phi.""" + return self._Phi_basis.M + + @property + def N_Phi(self): + """int: Toroidal resolution of periodic part of Phi.""" + return self._Phi_basis.N + + def compute( + self, + names, + grid, + params=None, + transforms=None, + data=None, + RpZ_data=None, + RpZ_grid=None, + override_grid=True, + **kwargs, + ): + """Compute the quantity given by name on grid. + + Parameters + ---------- + names : str or array-like of str + Name(s) of the quantity(s) to compute. + grid : Grid + Grid of coordinates on which to perform computation. + params : dict[str, jnp.ndarray] + Parameters from the equilibrium, such as R_lmn, Z_lmn, i_l, p_l, etc + Defaults to attributes of self. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from ``grid``. + data : dict[str, jnp.ndarray] + Data computed so far, generally output from other compute functions. + Any vector v = vΒΉ RΜ‚ + vΒ² Ο•Μ‚ + vΒ³ ZΜ‚ should be given in components + v = [vΒΉ, vΒ², vΒ³] where RΜ‚, Ο•Μ‚, ZΜ‚ are the normalized basis vectors + of the cylindrical coordinates R, Ο•, Z. + RpZ_data : dict[str, jnp.ndarray] + Data evaluated so far on the (R, Ο•, Z) coordinates in this dictionary. + Should store the three entries ``"R"``, ``"phi"``, and ``"Z"`` + if the intention is to compute something at these coordinates. + If not given, then computes from ``RpZ_grid``. + RpZ_grid : Grid + Grid of coordinates on which to evaluate quantities that support + evaluation off of ``grid``. + If not given, then default is ``grid``. + override_grid : bool + If True, override ``grid`` if necessary and use a full + resolution grid to compute quantities and then downsample to ``grid``. + If False, uses only ``grid``, which may lead to + inaccurate values for surface or volume averages. + + Returns + ------- + data : dict[str, jnp.ndarray] + Quantities and intermediate variables computed on ``grid``. + RpZ_data : dict[str, jnp.ndarray] + Quantities and intermediate variables computed on the + (R, Ο•, Z) coordinates in ``RpZ_data``. + + """ + errorif( + self.M_Phi > grid.M, msg=f"Got M_Phi = {self.M_Phi} > {grid.M} = grid.M." + ) + errorif( + self.N_Phi > grid.N, msg=f"Got N_Phi = {self.N_Phi} > {grid.N} = grid.N." + ) + + kwargs.setdefault("B0", self._B0) + + # to simplify computation of a singular integral for βˆ‡Ο† + if kwargs.get("on_boundary", False) and "eval_interpolator" not in kwargs: + if RpZ_grid is None: + errorif(RpZ_data is not None, msg="Please supply RpZ_grid.") + else: + kwargs["eval_interpolator"] = get_interpolator( + eval_grid=RpZ_grid, + source_grid=grid, + source_data=super().compute( + ["|e_theta x e_zeta|", "e_theta", "e_zeta"], + grid, + params, + transforms, + data, + override_grid, + **kwargs, + ), + **kwargs, + ) + + if RpZ_data is None: + if RpZ_grid is None: + RpZ_grid = grid + RpZ_data = data + same_grid = True + else: + same_grid = False + RpZ_data = super().compute( + ["R", "phi", "Z"], + RpZ_grid, + params, + transforms, + data=RpZ_data, + override_grid=override_grid, + **kwargs, + ) + if same_grid: + data = RpZ_data + + return super().compute( + names, + grid, + params, + transforms, + data, + override_grid, + RpZ_data=RpZ_data, + **kwargs, + ) + + +class FreeSurfaceOuterField(SourceFreeField): + """Compute field on outer plasma for free surface. + + Implements the interior Dirichlet formulation in multiply connected + geometry described in [1]_. + + Parameters + ---------- + surface : Surface + Geometry defining βˆ‚π’³. + M : int + Poloidal Fourier resolution to interpolate potential on βˆ‚π’³. + N : int + Toroidal Fourier resolution to interpolate potential on βˆ‚π’³. + sym : str + Symmetry for Fourier basis interpolating the periodic part of the + potential. Default is ``sin`` when the surface is stellarator + symmetric and ``False`` otherwise. + M_coil : int + Poloidal Fourier resolution to interpolate coil potential on βˆ‚π’³. + Default is ``M``. + N_coil : int + Poloidal Fourier resolution to interpolate coil potential on βˆ‚π’³. + Default is ``N``. + sym_coil : str + Symmetry for Fourier basis interpolating the periodic part of the + coil potential. Default is ``sym``. + B_coil : _MagneticField + Magnetic field from coil current sources. + This must be smooth and divergence free for correctness. + Y_coil : float + Net poloidal current determining circulation of coil field. + Default is to compute from ``B_coil``. + I_plasma : float + Net toroidal plasma current determining a circulation of Ξ¦. + Default is zero. + I_sheet : float + Net toroidal sheet current determining a circulation of Ξ¦. + Default is zero. + + """ + + _immediate_attributes_ = ["_Phi_coil_basis", "_B_coil"] + + def __init__( + self, + surface, + M, + N, + sym=None, + M_coil=None, + N_coil=None, + sym_coil=None, + B_coil=None, + Y_coil=None, + I_plasma=0.0, + I_sheet=0.0, + ): + sym = setdefault(sym, "sin" if surface.sym else False) + I = I_plasma + I_sheet # noqa: E741 + + super().__init__( + surface, + M, + N, + surface.NFP, + sym, + FreeSurfaceOuterField._B0(I, Y_coil), + I, + Y_coil, + ) + if M_coil is None and N_coil is None and sym_coil is None: + self._Phi_coil_basis = self._Phi_basis + else: + self._Phi_coil_basis = DoubleFourierSeries( + M=setdefault(M_coil, M), + N=setdefault(N_coil, N), + NFP=surface.NFP, + sym=setdefault(sym_coil, sym), + ) + self._B_coil = B_coil + + @staticmethod + def _B0(I, Y): # noqa: E741 + """Returns βˆ‡(Ξ¦ (secular)).""" + warnif( + I != 0, + NotImplementedError, + "Must supply B0 as kwarg in compute method for correctness.", + ) + return ToroidalMagneticField(setdefault(Y, 0), 1) + + def __setattr__(self, name, value): + if ( + name in FreeSurfaceOuterField._immediate_attributes_ + or name in SourceFreeField._immediate_attributes_ + ): + object.__setattr__(self, name, value) + else: + setattr(object.__getattribute__(self, "_surface"), name, value) + + @property + def Phi_coil_basis(self): + """DoubleFourierSeries: Basis for periodic part of coil potential.""" + return self._Phi_coil_basis + + @property + def sym_Phi_coil(self): + """str: Symmetry of periodic part of Phi_coil (no symmetry if False).""" + return self._Phi_coil_basis.sym + + @property + def M_Phi_coil(self): + """int: Poloidal resolution of periodic part of Phi_coil.""" + return self._Phi_coil_basis.M + + @property + def N_Phi_coil(self): + """int: Toroidal resolution of periodic part of Phi_coil.""" + return self._Phi_coil_basis.N + + def compute( + self, + names, + grid, + params=None, + transforms=None, + data=None, + RpZ_data=None, + RpZ_grid=None, + override_grid=True, + **kwargs, + ): + """Compute the quantity given by name on grid.""" + errorif( + self.M_Phi_coil > grid.M, + msg=f"Got M_Phi_coil = {self.M_Phi_coil} > {grid.M} = grid.M.", + ) + errorif( + self.N_Phi_coil > grid.N, + msg=f"Got N_Phi_coil = {self.N_Phi_coil} > {grid.N} = grid.N.", + ) + kwargs.setdefault("B_coil", self._B_coil) + if self.Y is None and (params is None or "Y" not in params): + data, RpZ_data = super().compute( + "Y_coil", + grid, + params, + transforms, + data, + RpZ_data, + RpZ_grid, + override_grid, + **kwargs, + ) + params = setdefault(params, {}) + params["Y"] = data["Y_coil"] + return super().compute( + names, + grid, + params, + transforms, + data, + RpZ_data, + RpZ_grid, + override_grid, + **kwargs, + ) diff --git a/desc/objectives/__init__.py b/desc/objectives/__init__.py index 52ca059e98..3cd042ffad 100644 --- a/desc/objectives/__init__.py +++ b/desc/objectives/__init__.py @@ -27,7 +27,7 @@ RadialForceBalance, ) from ._fast_ion import GammaC -from ._free_boundary import BoundaryError, VacuumBoundaryError +from ._free_boundary import BoundaryError, FreeSurfaceError, VacuumBoundaryError from ._generic import ( DeflationOperator, ExternalObjective, diff --git a/desc/objectives/_coils.py b/desc/objectives/_coils.py index edfbdf8f2c..83b73b30a9 100644 --- a/desc/objectives/_coils.py +++ b/desc/objectives/_coils.py @@ -1832,7 +1832,7 @@ class SurfaceQuadraticFlux(_Objective): coils). Default grid is determined by the specific MagneticField object, see the docs of that object's ``compute_magnetic_field`` method for more detail. field_fixed : bool - Whether or not to fix the magnetic field's DOFs during the optimization. + Whether to fix the magnetic field's DOFs during the optimization. bs_chunk_size : int or None Size to split Biot-Savart computation into chunks of evaluation points. If no chunking should be done or the chunk size is the full input diff --git a/desc/objectives/_free_boundary.py b/desc/objectives/_free_boundary.py index b62603289d..5edda32af0 100644 --- a/desc/objectives/_free_boundary.py +++ b/desc/objectives/_free_boundary.py @@ -1,28 +1,54 @@ """Objectives for solving free boundary equilibria.""" import numpy as np +from jax.lax import stop_gradient from scipy.constants import mu_0 from desc.backend import jnp from desc.compute import get_profiles, get_transforms +from desc.compute._laplace import Options as LaplaceOptions from desc.compute.utils import _compute as compute_fun from desc.grid import LinearGrid -from desc.integrals import DFTInterpolator, FFTInterpolator, virtual_casing_biot_savart +from desc.integrals import get_interpolator, virtual_casing_biot_savart +from desc.io import IOAble from desc.nestor import Nestor from desc.objectives.objective_funs import _Objective, collect_docs +from desc.optimizable import Optimizable, optimizable_parameter from desc.utils import ( PRINT_WIDTH, Timer, + cross, + dot, errorif, parse_argname_change, setdefault, warnif, ) -from ..integrals.singularities import best_params, best_ratio from .normalization import compute_scaling_factors +class _FreeSurfaceSheetCurrent(IOAble, Optimizable): + """Optimizable toroidal sheet-current parameter for FreeSurfaceError.""" + + _io_attrs_ = ["_I_sheet"] + + def __init__(self): + self.I_sheet = 0.0 + + @optimizable_parameter + @property + def I_sheet(self): + """float: Net toroidal sheet current determining a circulation of Phi.""" + return self._I_sheet + + @I_sheet.setter + def I_sheet(self, new): + new = jnp.asarray(new) + assert new.size == 1 + self._I_sheet = new.squeeze() + + class VacuumBoundaryError(_Objective): """Target for free boundary conditions on LCFS for vacuum equilibrium. @@ -545,51 +571,20 @@ def build(self, use_jit=True, verbose=1): else: source_grid = self._source_grid - if self._eval_grid is None: - eval_grid = source_grid - else: - eval_grid = self._eval_grid - + eval_grid = setdefault(self._eval_grid, source_grid) self._use_same_grid = eval_grid.equiv(source_grid) - errorif( - not np.all(source_grid.nodes[:, 0] == 1.0), - ValueError, - "source_grid contains nodes not on rho=1", + ratio_data = ( + eq.compute(["|e_theta x e_zeta|", "e_theta", "e_zeta"], grid=source_grid) + if (self._st is None or self._sz is None or self._q is None) + else {} ) - errorif( - not np.all(eval_grid.nodes[:, 0] == 1.0), - ValueError, - "eval_grid contains nodes not on rho=1", + interpolator = get_interpolator( + eval_grid, source_grid, ratio_data, st=self._st, sz=self._sz, q=self._q ) - errorif( - source_grid.sym, - ValueError, - "Source grids for singular integrals must be non-symmetric", - ) - - if self._st is None or self._sz is None or self._q is None: - ratio_data = eq.compute( - ["|e_theta x e_zeta|", "e_theta", "e_zeta"], grid=source_grid - ) - st, sz, q = best_params(source_grid, best_ratio(ratio_data)) - self._st = setdefault(self._st, st) - self._sz = setdefault(self._sz, sz) - self._q = setdefault(self._q, q) - - try: - interpolator = FFTInterpolator( - eval_grid, source_grid, self._st, self._sz, self._q - ) - except AssertionError as e: - warnif( - True, - msg="Could not build fft interpolator, switching to dft which is slow." - "\nReason: " + str(e), - ) - interpolator = DFTInterpolator( - eval_grid, source_grid, self._st, self._sz, self._q - ) + del self._st + del self._sz + del self._q edge_pres = np.max(np.abs(eq.compute("p", grid=eval_grid)["p"])) warnif( @@ -779,7 +774,7 @@ def compute(self, eq_params, *field_params, constants=None): ) Bjump = Bex_total - Bin_total if self._sheet_current: - Kerr = mu_0 * sheet_eval_data["K"] - jnp.cross(eval_data["n_rho"], Bjump) + Kerr = mu_0 * sheet_eval_data["K"] - cross(eval_data["n_rho"], Bjump) Kerr = jnp.linalg.norm(Kerr, axis=-1) * g return jnp.concatenate([Bn_err, Bsq_err, Kerr]) else: @@ -901,6 +896,505 @@ def _print(fmt, fmax, fmin, fmean, f0max, f0min, f0mean, norm, unit): return out +class FreeSurfaceError(_Objective): + """Target for free surface ideal MHD equilirium as described in [1]_. + + References + ---------- + .. [1] Unalmis et al. New high-order accurate free surface stellarator + equilibria optimization and boundary integral methods in DESC. + + Notes + ----- + Performance is expected to improve significantly by resolving GitHub + issues #1034 and #2171. + + If reverse mode differentiation is being used, it is of great benefit for + the objective residual to be a lower dimensional item. In such cases, it is + better to instead use a loss function that reduces the dimension of the + residual before computing the derivative relevant for optimization. This can + be a mean squared error over all points of the output grid or mean absolute + error over blocks of the grid, etc. + + Parameters + ---------- + eq : Equilibrium + ``Equilibrium`` to be optimized. + field : FreeSurfaceOuterField or SourceFreeField + Laplace solver object. + + If is an instance of ``FreeSurfaceOuterField`` + assumes ``field._B_coil`` is the magnetic field due to coils. + If is an instance of ``SourceFreeField`` then assumes ``field._B0`` is + the magnetic field due to coils. + + The net toroidal sheet current ``I_sheet`` is an optimizable scalar + parameter initialized to zero. + grid : Grid + Grid for the integral transforms. + Tensor-product grid in (ΞΈ, ΞΆ) with uniformly spaced nodes + (ΞΈ, ΞΆ) ∈ [0, 2Ο€) Γ— [0, 2Ο€/NFP) on the boundary. + Default is ``LinearGrid(M=eq.M_grid,N=eq.N_grid,NFP=eq.NFP)``. + coil_grid : Grid, optional + Source grid used to discretize coil magnetic field computation. + Default is default grid of coil magnetic field. + q : int + Order of integration on the local singular grid. + fix_I_sheet : bool, optional + Whether to fix the net toroidal sheet current to zero instead of optimizing it. + options : LaplaceOptions + Options for the Laplace solver. + + """ + + __doc__ = __doc__.rstrip() + collect_docs( + target_default="``target=0``.", + bounds_default="``target=0``.", + ) + + _scalar = False + _print_value_fmt = "Free surface Error: " + _units = "T^2 m^2" + + _static_attrs = _Objective._static_attrs + [ + "_is_neumann", + "_field", + "_B_coil", + "_use_same_grid", + "_q", + "_fix_I_sheet", + "_options", + "_grad_keys", + "_inner_keys", + "_reuseable_keys", + ] + + _coordinates = "rtz" + + def __init__( + self, + eq, + field, + *, + grid=None, + coil_grid=None, + q=None, + fix_I_sheet=False, + options=None, + target=None, + bounds=None, + weight=1, + normalize=True, + normalize_target=True, + loss_function=None, + deriv_mode="auto", + jac_chunk_size=None, + name="Free surface error", + **kwargs, + ): + if target is None and bounds is None: + target = 0.0 + assert fix_I_sheet in {True, False} + + if grid is None: + grid = LinearGrid( + rho=np.array([1.0]), + M=eq.M_grid, + N=eq.N_grid, + NFP=eq.NFP if eq.N > 0 else 64, + sym=False, + ) + assert grid.can_fft2 + errorif(field.M_Phi > grid.M, msg=f"M_Phi = {field.M_Phi} > {grid.M} = grid.M.") + errorif(field.N_Phi > grid.N, msg=f"N_Phi = {field.N_Phi} > {grid.N} = grid.N.") + + self._is_neumann = not hasattr(field, "M_Phi_coil") + errorif( + not self._is_neumann and field.M_Phi_coil > grid.M, + msg=f"M_Phi_coil = {getattr(field, 'M_Phi_coil', 0)} > {grid.M} = grid.M.", + ) + errorif( + not self._is_neumann and field.N_Phi_coil > grid.N, + msg=f"N_Phi_coil = {getattr(field, 'N_Phi_coil', 0)} > {grid.N} = grid.N.", + ) + eval_grid = ( + grid + if (grid.M == field.M_Phi and grid.N == field.N_Phi) + else LinearGrid(M=field.M_Phi, N=field.N_Phi, NFP=grid.NFP, sym=False) + ) + assert eval_grid.can_fft2 + + errorif( + field.M_Phi != eval_grid.M, + msg=f"M_Phi = {field.M_Phi} != {eval_grid.M} = eval_grid.M.", + ) + errorif( + field.N_Phi != eval_grid.N, + msg=f"N_Phi = {field.N_Phi} != {eval_grid.N} = eval_grid.N.", + ) + errorif( + not self._is_neumann and field.M_Phi_coil > eval_grid.M, + msg=( + f"M_Phi_coil = {getattr(field, 'M_Phi_coil', 0)} > " + f"{eval_grid.M} = eval_grid.M." + ), + ) + errorif( + not self._is_neumann and field.N_Phi_coil > eval_grid.N, + msg=( + f"N_Phi_coil = {getattr(field, 'N_Phi_coil', 0)} > " + f"{eval_grid.N} = eval_grid.N." + ), + ) + + self._field = field + self._B_coil = field._B0 if self._is_neumann else field._B_coil + self._grid = grid + self._eval_grid = eval_grid + self._coil_grid = coil_grid + self._use_same_grid = grid.equiv(eval_grid) + self._q = q + self._fix_I_sheet = fix_I_sheet + I_sheet = _FreeSurfaceSheetCurrent() + if options is None: + options = LaplaceOptions() + else: + options = LaplaceOptions(*options) + options = options._replace( + problem="exterior Neumann" if self._is_neumann else "interior Dirichlet" + ) + self._options = tuple(options) # DESC is dumb and casts NamedTuples to Tuples + self._grad_keys = ["grad(theta)", "grad(zeta)", "n_rho"] + self._inner_keys = [ + "|B|^2", + "p", + "I", + "R", + "phi", + "omega", + "Z", + "|e_theta x e_zeta|", + ] + self._grad_keys + self._reuseable_keys = [ + "0", + "R", + "phi", + "omega", + "R_t", + "R_z", + "Z", + "Z_t", + "Z_z", + "e_theta", + "e_theta x e_zeta", + "e_zeta", + "n_rho", + "omega_t", + "omega_z", + "|e_theta x e_zeta|", + ] + + things = [eq] if fix_I_sheet else [eq, I_sheet] + super().__init__( + things=things, + target=target, + bounds=bounds, + weight=weight, + normalize=normalize, + normalize_target=normalize_target, + loss_function=loss_function, + deriv_mode=deriv_mode, + name=name, + jac_chunk_size=jac_chunk_size, + ) + + def build(self, use_jit=True, verbose=1): + """Build constant arrays. + + Parameters + ---------- + use_jit : bool, optional + Whether to just-in-time compile the objective and derivatives. + verbose : int, optional + Level of output. + + """ + eq = self.things[0] + options = LaplaceOptions(*self._options) + + eq_transforms = get_transforms(self._inner_keys, eq, grid=self._eval_grid) + eval_transforms = get_transforms("|K_vc|^2", self._field, grid=self._eval_grid) + if self._use_same_grid: + source_transforms = eval_transforms + grad_transforms = eq_transforms + else: + source_transforms = get_transforms("Phi_mn", self._field, grid=self._grid) + grad_transforms = get_transforms( + self._grad_keys + ["phi", "omega", "Z"], eq, grid=self._grid + ) + + data, _ = self._field.compute( + ["interpolator"] if self._is_neumann else ["interpolator", "Y_coil"], + grid=self._grid, + q=self._q, + transforms=source_transforms, + B_coil=self._B_coil, + options=options, + potential_grid=self._eval_grid, + ) + # No net poloidal current in equation 4.13 of [1]. + self._field.Y = 0.0 if self._is_neumann else data["Y_coil"] + profiles = get_profiles(self._inner_keys, eq, grid=self._eval_grid) + initial_guess = self._compute_initial_guess( + eq, + source_transforms, + eval_transforms, + profiles, + data["interpolator"], + None if self._fix_I_sheet else self.things[1].params_dict, + ) + + self._constants = { + "interpolator": data["interpolator"], + "eq_transforms": eq_transforms, + "grad_transforms": grad_transforms, + "eval_transforms": eval_transforms, + "source_transforms": source_transforms, + "profiles": profiles, + "initial_guess": initial_guess, + "quad_weights": np.sqrt(eval_transforms["grid"].weights), + } + self._dim_f = self._eval_grid.num_nodes + + if self._normalize: + scales = compute_scaling_factors(eq) + self._normalization = ( + np.ones(self._eval_grid.num_nodes) + * scales["B"] ** 2 + * scales["R0"] + * scales["a"] + ) + + super().build(use_jit=use_jit, verbose=verbose) + + def _compute_initial_guess( + self, + eq, + source_transforms, + eval_transforms, + profiles, + interpolator, + I_sheet_params=None, + ): + """Compute the potential used to initialize iterative solves.""" + options = LaplaceOptions(*self._options) + params = eq.params_dict + I_sheet = 0.0 if I_sheet_params is None else I_sheet_params["I_sheet"][0] + source_grid = self._grid + source_keys = self._reuseable_keys + ["grad(theta)", "grad(zeta)", "I"] + source_data = eq.compute( + source_keys, + grid=source_grid, + ) + field_params = { + "R_lmn": params["Rb_lmn"], + "Z_lmn": params["Zb_lmn"], + "I": source_data["I"][source_grid.unique_rho_idx[-1]] + I_sheet, + "Y": self._field.Y, + } + data = {key: source_data[key] for key in self._reuseable_keys} + data["interpolator"] = interpolator + if not self._use_same_grid: + data["potential data"] = eq.compute(["R", "phi", "Z"], grid=self._eval_grid) + data["B0*n"] = self._phi_sec_dot_n(field_params, source_data) + if self._is_neumann: + data, _ = self._field.compute( + "B_coil", + grid=source_grid, + params=field_params, + transforms=source_transforms, + data=data, + options=options, + B_coil=self._B_coil, + field_grid=self._coil_grid, + ) + data["B0*n"] += dot(data["B_coil"], data["n_rho"]) + elif not self._use_same_grid: + potential_field_data, _ = self._field.compute( + "Phi_coil (periodic)", + grid=self._eval_grid, + params=field_params, + transforms=eval_transforms, + data={"Y_coil": self._field.Y}, + options=options, + B_coil=self._B_coil, + field_grid=self._coil_grid, + ) + data["Phi_coil (periodic)"] = potential_field_data["Phi_coil (periodic)"] + + data = compute_fun( + self._field, + "Phi (periodic)", + field_params, + eval_transforms, + profiles, + data=data, + options=options._replace(solve_method="gmres"), + B_coil=self._B_coil, + field_grid=self._coil_grid, + ) + # We differentiate through the solution, not the initial guess, + # so we stop the gradient for numerical stability. + return stop_gradient(data["Phi (periodic)"]) + + def compute(self, params, I_sheet_params=None, constants=None): + """Compute boundary error. + + Parameters + ---------- + params : dict + Dictionary of equilibrium degrees of freedom, e.g. + ``Equilibrium.params_dict``. + I_sheet_params : dict + Dictionary containing the optimizable sheet current ``I_sheet``. If omitted, + the sheet current is fixed to zero. + constants : dict + Dictionary of constant data, e.g. transforms, profiles etc. + Defaults to ``self.constants``. + + Returns + ------- + f : ndarray + Boundary error [[BΒ² + 2ΞΌβ‚€p]]*area Jacobian in TΒ² mΒ². + + """ + constants = self._get_deprecated_constants(constants) + eq = self.things[0] + I_sheet = 0.0 if I_sheet_params is None else I_sheet_params["I_sheet"][0] + options = LaplaceOptions(*self._options)._replace( + Phi_0=constants["initial_guess"] + ) + + inner = compute_fun( + eq, + self._inner_keys, + params, + constants["eq_transforms"], + constants["profiles"], + ) + field_params = { + "R_lmn": params["Rb_lmn"], + "Z_lmn": params["Zb_lmn"], + # This is I_plasma + I_sheet. + "I": inner["I"][self._eval_grid.unique_rho_idx[-1]] + I_sheet, + "Y": self._field.Y, + } + outer = {key: inner[key] for key in self._reuseable_keys} + + if self._is_neumann: + outer = compute_fun( + self._field, + "B_coil", + field_params, + constants["eval_transforms"], + constants["profiles"], + data=outer, + options=options, + B_coil=self._B_coil, + field_grid=self._coil_grid, + ) + elif not self._use_same_grid: + outer = compute_fun( + self._field, + "Phi_coil (periodic)", + field_params, + constants["eval_transforms"], + constants["profiles"], + data=outer, + options=options, + B_coil=self._B_coil, + field_grid=self._coil_grid, + ) + + potential_data = ( + None + if self._use_same_grid + else {key: inner[key] for key in ["R", "phi", "Z"]} + ) + + if self._use_same_grid: + outer["interpolator"] = constants["interpolator"] + outer["B0*n"] = self._phi_sec_dot_n(field_params, inner) + if self._is_neumann: + outer["B0*n"] += dot(outer["B_coil"], inner["n_rho"]) + else: + grads = compute_fun( + eq, + self._grad_keys + ["phi", "omega", "Z"], + params, + constants["grad_transforms"], + constants["profiles"], + ) + data = {key: grads[key] for key in self._reuseable_keys} + data["interpolator"] = constants["interpolator"] + if potential_data is not None: + data["potential data"] = potential_data + if not self._is_neumann: + data["Phi_coil (periodic)"] = outer["Phi_coil (periodic)"] + data["B0*n"] = self._phi_sec_dot_n(field_params, grads) + if self._is_neumann: + data = compute_fun( + self._field, + "B_coil", + field_params, + constants["source_transforms"], + constants["profiles"], + data=data, + options=options, + B_coil=self._B_coil, + field_grid=self._coil_grid, + ) + data["B0*n"] += dot(data["B_coil"], data["n_rho"]) + + outer["Phi_mn"] = compute_fun( + self._field, + "Phi_mn", + field_params, + constants["eval_transforms"], + constants["profiles"], + data=data, + options=options, + B_coil=self._B_coil, + field_grid=self._coil_grid, + )["Phi_mn"] + + outer = compute_fun( + self._field, + ["K_vc", "n_rho x B_coil"] if self._is_neumann else "|K_vc|^2", + field_params, + constants["eval_transforms"], + constants["profiles"], + data=outer, + options=options, + B_coil=self._B_coil, + field_grid=self._coil_grid, + ) + if self._is_neumann: + outer["K_vc"] -= outer["n_rho x B_coil"] + outer["|K_vc|^2"] = dot(outer["K_vc"], outer["K_vc"]) + + return (outer["|K_vc|^2"] - inner["|B|^2"] - 2 * mu_0 * inner["p"]) * inner[ + "|e_theta x e_zeta|" + ] + + @staticmethod + def _phi_sec_dot_n(params, grads): + return dot( + params["I"] * grads["grad(theta)"] + params["Y"] * grads["grad(zeta)"], + grads["n_rho"], + ) + + class BoundaryErrorNESTOR(_Objective): """Pressure balance across LCFS. diff --git a/desc/utils.py b/desc/utils.py index 34088ebb01..032fb7544d 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -888,18 +888,19 @@ def safenorm(x, ord=None, axis=None, fill=0, threshold=0, keepdims=False): """ is_zero = (jnp.abs(x) <= threshold).all(axis=axis, keepdims=True) - y = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero - n = jnp.linalg.norm(y, ord=ord, axis=axis) - n = jnp.where(is_zero.squeeze(), fill, n) # replace norm with zero if is_zero - if keepdims: - axis = 0 if axis is None else axis - n = jnp.expand_dims(n, axis) + y = jnp.where(is_zero, jnp.ones_like(x), x) + if not keepdims: + is_zero = is_zero.squeeze(axis=axis) + n = jnp.linalg.norm(y, ord=ord, axis=axis, keepdims=keepdims) + n = jnp.where(is_zero, fill, n) return n def safenormalize(x, ord=None, axis=None, fill=0, threshold=0): """Normalize a vector to unit length, but without nan gradient at x=0. + If x is zero returns a constant array of unit length. + Parameters ---------- x : ndarray @@ -915,10 +916,10 @@ def safenormalize(x, ord=None, axis=None, fill=0, threshold=0): """ is_zero = (jnp.abs(x) <= threshold).all(axis=axis, keepdims=True) - y = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero - n = safenorm(x, ord, axis, fill, threshold, keepdims=True) * jnp.ones_like(x) + y = jnp.where(is_zero, jnp.ones_like(x), x) + n = safenorm(x, ord, axis, fill, threshold, keepdims=True) # return unit vector with equal components if norm <= threshold - return jnp.where(n <= threshold, jnp.ones_like(y) / jnp.sqrt(y.size), y / n) + return jnp.where(n <= threshold, jnp.reciprocal(jnp.sqrt(x.size)), y / n) def safediv(a, b, fill=0, threshold=0): @@ -1160,7 +1161,7 @@ def apply(d, fun=identity, subset=None, exclude=None): elif isinstance(subset, str): subset = (subset,) exclude = () if (exclude is None) else exclude - return {k: fun(d[k]) for k in subset if k not in exclude} + return {k: fun(d[k]) for k in subset if (k in d and k not in exclude)} def get_ess_scale(modes, alpha=1.2, order=np.inf, min_value=1e-7): diff --git a/devtools/check_unmarked_tests.py b/devtools/check_unmarked_tests.py index 150a67d2cb..1c7a2a087b 100644 --- a/devtools/check_unmarked_tests.py +++ b/devtools/check_unmarked_tests.py @@ -8,7 +8,7 @@ import ast import sys -REQUIRED_MARKS = {"unit", "regression", "benchmark", "memory"} +REQUIRED_MARKS = {"unit", "regression", "benchmark", "memory", "skip"} def _pytest_marks(decorators): diff --git a/devtools/check_unmarked_tests.sh b/devtools/check_unmarked_tests.sh index 1fceb72e68..c027faf6c3 100755 --- a/devtools/check_unmarked_tests.sh +++ b/devtools/check_unmarked_tests.sh @@ -6,4 +6,20 @@ if [ -f "devtools/pre-commit.log" ]; then fi echo "Files to check: $@" +# Collect unmarked tests for the specific file and suppress errors +unmarked=$(pytest "$@" --collect-only -m "not unit and not regression and not benchmark and not memory and not skip" -q 2> /dev/null | head -n 2) + +# Count the number of unmarked tests found, ignoring empty lines and the line emitted if pytest found no unmarked tests +num_unmarked=$(echo "$unmarked" | sed '/^\s*$/d;/no tests collected/d' | wc -l) + +# If there are any unmarked tests, print them and exit with status 1 +if [ "$num_unmarked" -gt 0 ]; then + echo "----found $num_unmarked unmarked tests----" + echo "$unmarked" + # Calculate the elapsed time and print with a newline + end_time=$(date +%s) + elapsed_time=$((end_time - start_time)) + printf "\nTime taken: %d seconds" "$elapsed_time" + exit 1 +fi python devtools/check_unmarked_tests.py "$@" diff --git a/devtools/dev-requirements.txt b/devtools/dev-requirements.txt index e83a283120..e5c7dd2534 100644 --- a/devtools/dev-requirements.txt +++ b/devtools/dev-requirements.txt @@ -37,6 +37,7 @@ pytest-split >= 0.8.2, <= 0.11.0 qicna @ git+https://github.com/rogeriojorge/pyQIC/ qsc <= 0.1.3 shapely >= 1.8.2, <= 2.1.2 +optimistix # building build diff --git a/publications/unalmis2025/free_surface_error_profile.py b/publications/unalmis2025/free_surface_error_profile.py new file mode 100644 index 0000000000..569cb2377b --- /dev/null +++ b/publications/unalmis2025/free_surface_error_profile.py @@ -0,0 +1,47 @@ +"""Profile the FreeSurfaceError objective. + +Profiling requires python < 3.14. + - pip install xprof tensorboard tensorboard_plugin_profile + - cd DESC/publications/unalmis2025 + - python free_surface_error_profile.py + - tensorboard --logdir=/tmp/profile-data + +""" + +import numpy as np + +from desc.backend import jax +from desc.examples import get +from desc.grid import LinearGrid +from desc.magnetic_fields import FreeSurfaceOuterField, ToroidalMagneticField +from desc.objectives import ForceBalance, FreeSurfaceError, ObjectiveFunction +from desc.optimize import ProximalProjection + +eq = get("W7-X") +grid = LinearGrid(rho=np.array([1.0]), M=8, N=8, NFP=eq.NFP, sym=False) +B_coil = ToroidalMagneticField(5, 1) + +field = FreeSurfaceOuterField(eq.surface, M=8, N=8, B_coil=B_coil) +obj = ObjectiveFunction( + [ + FreeSurfaceError( + eq, + field, + grid=grid, + solve_method="gmres", + deriv_mode="fwd", + ) + ] +) +constraint = ObjectiveFunction([ForceBalance(eq)]) +prox = ProximalProjection( + obj, constraint, eq, solve_options={"solve_during_proximal_build": False} +) +prox.build() +x = prox.x(eq) + +err = prox.compute_scaled_error(x, prox.constants).block_until_ready() + +with jax.profiler.trace("/tmp/profile-data"): + with jax.profiler.TraceAnnotation("Benchmarking FreeSurfaceError"): + err = prox.compute_scaled_error(x, prox.constants).block_until_ready() diff --git a/requirements.txt b/requirements.txt index 110d8633ef..2c1614ca75 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,8 @@ diffrax >= 0.6.0, <= 0.7.2 equinox >=0.11.10, <=0.13.8 h5py >= 3.0.0, <= 3.16.0 interpax >= 0.3.3, < 0.4 -interpax_fft >= 0.0.6, <= 0.0.7 +interpax_fft >= 0.0.9, <= 0.0.9 +lineax jax-finufft >= 1.1.0, <= 1.3.1 matplotlib >= 3.7.3, <= 3.10.8 mpmath >= 1.0.0, <= 1.4.1 diff --git a/tests/inputs/master_compute_data_rpz.pkl b/tests/inputs/master_compute_data_rpz.pkl index a4ac2ba3aa..d8e5476f92 100644 Binary files a/tests/inputs/master_compute_data_rpz.pkl and b/tests/inputs/master_compute_data_rpz.pkl differ diff --git a/tests/test_axis_limits.py b/tests/test_axis_limits.py index 1335503f31..816f26efa2 100644 --- a/tests/test_axis_limits.py +++ b/tests/test_axis_limits.py @@ -77,6 +77,7 @@ "|grad(theta)|", " Redl", # may not exist for all configurations "current Redl", + "n_rho x grad(theta)", "J^theta_PEST", "(J^theta_PEST_v)|PEST", "(J^theta_PEST_p)|PEST", diff --git a/tests/test_backend.py b/tests/test_backend.py index 3cf96ec130..f731bcbcfd 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -131,9 +131,7 @@ def test_lstsq(): # square A = rng.standard_normal((5, 5)) b = rng.standard_normal(5) - np.testing.assert_allclose( - _lstsq(A, b), np.linalg.lstsq(A, b, rcond=None)[0], rtol=1e-6 - ) + np.testing.assert_allclose(_lstsq(A, b), np.linalg.solve(A, b), rtol=1e-6) # scalar A = rng.standard_normal((1, 5)) b = rng.standard_normal(1) diff --git a/tests/test_basis.py b/tests/test_basis.py index 80fa8fc813..0740c237ec 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -266,9 +266,8 @@ def test_double_fourier(self): ).T basis = DoubleFourierSeries(M=1, N=1) - values = basis.evaluate(grid.nodes, derivatives=np.array([0, 0, 0])) - - np.testing.assert_allclose(values, correct_vals, atol=1e-8) + assert basis.num_modes == (2 * basis.M + 1) * (2 * basis.N + 1) + np.testing.assert_allclose(basis.evaluate(grid), correct_vals, atol=1e-8) @pytest.mark.unit def test_change_resolution(self): diff --git a/tests/test_compute_everything.py b/tests/test_compute_everything.py index 9bc8900e0b..316d4c44c7 100644 --- a/tests/test_compute_everything.py +++ b/tests/test_compute_everything.py @@ -193,7 +193,12 @@ def test_compute_everything(): current=5, X=[5, 10, 2, 5], Y=[1, 2, 3, 1], Z=[-4, -5, -6, -4] ), } - assert things.keys() == data_index.keys(), ( + same_compute_fun_as_surface = { + # not testing these here + "desc.magnetic_fields._laplace.FreeSurfaceOuterField", + "desc.magnetic_fields._laplace.SourceFreeField", + } + assert things.keys() == (data_index.keys() - same_compute_fun_as_surface), ( f"Missing the parameterization {data_index.keys() - things.keys()}" f" to test against master." ) @@ -243,7 +248,10 @@ def test_compute_everything(): for p in things: - names = set(data_index[p].keys()) + names = set(data_index[p].keys()).copy() + # not clear why need to discard since these should not be in data_index[p] + names.discard("potential data") + names.discard("interpolator") def need_special(name): return bool(data_index[p][name]["source_grid_requirement"]) or bool( @@ -255,6 +263,8 @@ def need_special(name): this_branch_data_rpz[p] = things[p].compute( list(names), **grid.get(p, {}), basis="rpz" ) + this_branch_data_rpz[p].pop("potential data", None) + this_branch_data_rpz[p].pop("interpolator", None) # make sure we can compute everything assert this_branch_data_rpz[p].keys() == names, ( f"Parameterization: {p}. Can't compute " @@ -288,6 +298,7 @@ def need_special(name): this_branch_data_xyz = things[p].compute( list(names_xyz), **grid.get(p, {}), basis="xyz" ) + this_branch_data_xyz.pop("potential data", None) assert this_branch_data_xyz.keys() == names_xyz, ( f"Parameterization: {p}. Can't compute " + f"{names_xyz - this_branch_data_xyz.keys()}." diff --git a/tests/test_compute_funs.py b/tests/test_compute_funs.py index 39c6c59753..5c798a6385 100644 --- a/tests/test_compute_funs.py +++ b/tests/test_compute_funs.py @@ -1951,62 +1951,54 @@ def test_surface_equilibrium_geometry(): @pytest.mark.unit -def test_clebsch_sfl_funs(): +@pytest.mark.parametrize("eq", [get("W7-X"), get("NCSX")]) +def test_clebsch_sfl_funs(eq): """Test geometric and physical methods of computing B agree.""" - - def test(eq): - with pytest.warns(UserWarning, match="Reducing radial"): - eq.change_resolution(2, 2, 2, 4, 4, 4) - data = eq.compute( - [ - "e_zeta|r,a", - "B", - "B^zeta", - "B^phi", - "|B|_z|r,a", - "grad(|B|)", - "|e_zeta|r,a|_z|r,a", - "B^zeta_z|r,a", - "|B|", - "sqrt(g)_Clebsch", - "sqrt(g)_PEST", - "psi_r", - "grad(psi)", - "grad(alpha)", - "grad(phi)", - "B_phi", - "gbdrift (secular)", - "gbdrift (secular)/phi", - "phi", - ], - ) - np.testing.assert_allclose(data["e_zeta|r,a"], (data["B"].T / data["B^zeta"]).T) - np.testing.assert_allclose( - data["|B|_z|r,a"], dot(data["grad(|B|)"], data["e_zeta|r,a"]) - ) - np.testing.assert_allclose( - data["|e_zeta|r,a|_z|r,a"], - data["|B|_z|r,a"] / np.abs(data["B^zeta"]) - - data["|B|"] - * data["B^zeta_z|r,a"] - * np.sign(data["B^zeta"]) - / data["B^zeta"] ** 2, - ) - np.testing.assert_allclose( - data["B"], cross(data["grad(psi)"], data["grad(alpha)"]) - ) - np.testing.assert_allclose( - data["B^zeta"], data["psi_r"] / data["sqrt(g)_Clebsch"] - ) - np.testing.assert_allclose(data["B^phi"], data["psi_r"] / data["sqrt(g)_PEST"]) - np.testing.assert_allclose(data["B^phi"], dot(data["B"], data["grad(phi)"])) - np.testing.assert_allclose(data["B_phi"], data["B"][:, 1]) - np.testing.assert_allclose( - data["gbdrift (secular)"], data["gbdrift (secular)/phi"] * data["phi"] - ) - - test(get("W7-X")) - test(get("NCSX")) + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(2, 2, 2, 4, 4, 4) + data = eq.compute( + [ + "e_zeta|r,a", + "B", + "B^zeta", + "B^phi", + "|B|_z|r,a", + "grad(|B|)", + "|e_zeta|r,a|_z|r,a", + "B^zeta_z|r,a", + "|B|", + "sqrt(g)_Clebsch", + "sqrt(g)_PEST", + "psi_r", + "grad(psi)", + "grad(alpha)", + "grad(phi)", + "B_phi", + "gbdrift (secular)", + "gbdrift (secular)/phi", + "phi", + ], + ) + np.testing.assert_allclose(data["e_zeta|r,a"], (data["B"].T / data["B^zeta"]).T) + np.testing.assert_allclose( + data["|B|_z|r,a"], dot(data["grad(|B|)"], data["e_zeta|r,a"]) + ) + np.testing.assert_allclose( + data["|e_zeta|r,a|_z|r,a"], + data["|B|_z|r,a"] / np.abs(data["B^zeta"]) + - data["|B|"] + * data["B^zeta_z|r,a"] + * np.sign(data["B^zeta"]) + / data["B^zeta"] ** 2, + ) + np.testing.assert_allclose(data["B"], cross(data["grad(psi)"], data["grad(alpha)"])) + np.testing.assert_allclose(data["B^zeta"], data["psi_r"] / data["sqrt(g)_Clebsch"]) + np.testing.assert_allclose(data["B^phi"], data["psi_r"] / data["sqrt(g)_PEST"]) + np.testing.assert_allclose(data["B^phi"], dot(data["B"], data["grad(phi)"])) + np.testing.assert_allclose(data["B_phi"], data["B"][:, 1]) + np.testing.assert_allclose( + data["gbdrift (secular)"], data["gbdrift (secular)/phi"] * data["phi"] + ) @pytest.mark.unit diff --git a/tests/test_data_index.py b/tests/test_data_index.py index adeccede55..2f7dd72aea 100644 --- a/tests/test_data_index.py +++ b/tests/test_data_index.py @@ -69,9 +69,12 @@ def test_data_index_deps(): pattern_name = re.compile(r"(? 0: + R_lmn[(-m, -n)] -= C_r[(m, n)] + if (m, n) in C_z: + Z_lmn[(-m, abs(n))] += C_z[(m, n)] + if n < 0: + Z_lmn[(m, n)] -= C_z[(m, n)] + elif n > 0: + Z_lmn[(m, -n)] += C_z[(m, n)] + + grid = LinearGrid(rho=1, M=5, N=5) + R_bench = TestLaplace._manual_transform( + np.array(list(R_lmn.values())), + np.array([mn[0] for mn in R_lmn.keys()]), + np.array([mn[1] for mn in R_lmn.keys()]), + -grid.nodes[:, 1], # theta is flipped + grid.nodes[:, 2], + ) + R_merk = TestLaplace._merkel_transform( + np.array(list(C_r.values())), + np.array([mn[0] for mn in C_r.keys()]), + np.array([mn[1] for mn in C_r.keys()]), + -grid.nodes[:, 1], # theta is flipped + grid.nodes[:, 2], + ) + Z_bench = TestLaplace._manual_transform( + np.array(list(Z_lmn.values())), + np.array([mn[0] for mn in Z_lmn.keys()]), + np.array([mn[1] for mn in Z_lmn.keys()]), + -grid.nodes[:, 1], # theta is flipped + grid.nodes[:, 2], + ) + Z_merk = TestLaplace._merkel_transform( + np.array(list(C_z.values())), + np.array([mn[0] for mn in C_z.keys()]), + np.array([mn[1] for mn in C_z.keys()]), + -grid.nodes[:, 1], # theta is flipped + grid.nodes[:, 2], + fun=np.sin, + ) + np.testing.assert_allclose(R_bench, R_merk) + np.testing.assert_allclose(Z_bench, Z_merk) + with pytest.warns(UserWarning, match="Left handed"): + surf = FourierRZToroidalSurface( + R_lmn=list(R_lmn.values()), + Z_lmn=list(Z_lmn.values()), + modes_R=list(R_lmn.keys()), + modes_Z=list(Z_lmn.keys()), + ) + surf_data = surf.compute(["R", "Z"], grid=grid) + np.testing.assert_allclose(surf_data["R"], R_merk) + np.testing.assert_allclose(surf_data["Z"], Z_merk) + return surf + + @staticmethod + def _manual_transform(coef, m, n, theta, zeta): + """Evaluates Double Fourier Series of form G_n^m at theta and zeta pts.""" + op_four = np.where( + ((m < 0) & (n < 0))[:, np.newaxis], + np.sin(np.abs(m)[:, np.newaxis] * theta) + * np.sin(np.abs(n)[:, np.newaxis] * zeta), + n[:, np.newaxis] * zeta * np.nan, + ) + op_three = np.where( + ((m < 0) & (n >= 0))[:, np.newaxis], + np.sin(np.abs(m)[:, np.newaxis] * theta) * np.cos(n[:, np.newaxis] * zeta), + op_four, + ) + op_two = np.where( + ((m >= 0) & (n < 0))[:, np.newaxis], + np.cos(m[:, np.newaxis] * theta) * np.sin(np.abs(n)[:, np.newaxis] * zeta), + op_three, + ) + op_one = np.where( + ((m >= 0) & (n >= 0))[:, np.newaxis], + np.cos(m[:, np.newaxis] * theta) * np.cos(n[:, np.newaxis] * zeta), + op_two, + ) + return np.sum(coef[:, np.newaxis] * op_one, axis=0) + + @staticmethod + def _merkel_transform(coef, m, n, theta, zeta, fun=np.cos): + """Evaluates double Fourier series of form cos(m theta + n zeta).""" + return np.sum( + coef[:, np.newaxis] + * fun(m[:, np.newaxis] * theta + n[:, np.newaxis] * zeta), + axis=0, + ) class TestBouncePoints: diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 5403aef6ba..768f2cab0a 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -26,6 +26,7 @@ initialize_modular_coils, ) from desc.compute import get_transforms +from desc.compute._laplace import Options as LaplaceOptions from desc.equilibrium import Equilibrium from desc.examples import get from desc.geometry import FourierPlanarCurve, FourierRZToroidalSurface, FourierXYZCurve @@ -35,8 +36,10 @@ from desc.magnetic_fields import ( CurrentPotentialField, FourierCurrentPotentialField, + FreeSurfaceOuterField, OmnigenousField, PoloidalMagneticField, + SourceFreeField, SplineMagneticField, ToroidalMagneticField, VerticalMagneticField, @@ -95,7 +98,7 @@ Volume, get_NAE_constraints, ) -from desc.objectives._free_boundary import BoundaryErrorNESTOR +from desc.objectives._free_boundary import BoundaryErrorNESTOR, FreeSurfaceError from desc.objectives.nae_utils import ( _calc_1st_order_NAE_coeffs, _calc_2nd_order_NAE_coeffs, @@ -2165,6 +2168,115 @@ def test_objective_against_compute_ballooning(self): lam = w0 * lam.sum(axis=(-1, -2, -3)) + w1 * lam.max(axis=(-1, -2, -3)) np.testing.assert_allclose(obj.compute(eq.params_dict), lam) + @pytest.mark.unit + @pytest.mark.parametrize("solve_method", ["fixed_point", "gmres", "direct"]) + def test_objective_against_compute_free_surface_error(self, solve_method): + """Test FreeSurfaceError against the underlying |K_vc|^2 compute quantity.""" + eq = get("W7-X") + grid = LinearGrid(rho=np.array([1.0]), M=4, N=4, NFP=eq.NFP, sym=False) + B = ToroidalMagneticField(5, 1) + field = FreeSurfaceOuterField(eq.surface, M=grid.M, N=grid.N, B_coil=B) + obj = FreeSurfaceError( + eq, + field, + grid=grid, + options=LaplaceOptions(solve_method=solve_method), + ) + obj.build(verbose=0) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ResolutionWarning) + inner = eq.compute( + obj._inner_keys, + grid=grid, + params=eq.params_dict, + transforms=obj._constants["eq_transforms"], + profiles=obj._constants["profiles"], + override_grid=False, + ) + field_params = { + "R_lmn": eq.params_dict["Rb_lmn"], + "Z_lmn": eq.params_dict["Zb_lmn"], + "I": inner["I"][grid.unique_rho_idx[-1]], + "Y": field.Y, + } + outer_data = {key: inner[key] for key in obj._reuseable_keys} + outer_data["interpolator"] = obj._constants["interpolator"] + outer_data["B0*n"] = obj._phi_sec_dot_n(field_params, inner) + outer, _ = field.compute( + "|K_vc|^2", + grid=grid, + params=field_params, + transforms=obj._constants["eval_transforms"], + data=outer_data, + override_grid=False, + options=LaplaceOptions(*obj._options)._replace( + solve_method=solve_method, + Phi_0=obj._constants["initial_guess"], + ), + B_coil=B, + ) + expected = (outer["|K_vc|^2"] - inner["|B|^2"] - 2 * mu_0 * inner["p"]) * inner[ + "|e_theta x e_zeta|" + ] + + np.testing.assert_allclose( + obj.compute(eq.params_dict, obj.things[1].params_dict), expected + ) + + @pytest.mark.unit + def test_free_surface_error_optimizes_sheet_current(self): + """Test FreeSurfaceError exposes I_sheet as an optimizable parameter.""" + eq = get("W7-X") + grid = LinearGrid(rho=np.array([1.0]), M=2, N=2, NFP=eq.NFP, sym=False) + field = FreeSurfaceOuterField( + eq.surface, M=grid.M, N=grid.N, B_coil=ToroidalMagneticField(5, 1) + ) + obj = ObjectiveFunction( + FreeSurfaceError( + eq, + field, + grid=grid, + options=LaplaceOptions(solve_method="direct"), + ) + ) + obj.build(verbose=0) + + assert obj.things[1].optimizable_params == ["I_sheet"] + x0 = obj.x() + idx = obj.things[0].dim_x + obj.things[1].x_idx["I_sheet"][0] + grad = obj.grad(x0) + assert np.isfinite(grad[idx]) + assert not np.isclose(grad[idx], 0) + + step = 1e-5 * np.sign(grad[idx]) + x1 = x0.at[idx].add(-step) + assert obj.compute_scalar(x1) < obj.compute_scalar(x0) + + @pytest.mark.unit + def test_free_surface_error_can_fix_sheet_current(self): + """Test FreeSurfaceError can fix I_sheet to zero.""" + eq = get("W7-X") + grid = LinearGrid(rho=np.array([1.0]), M=2, N=2, NFP=eq.NFP, sym=False) + field = FreeSurfaceOuterField( + eq.surface, M=grid.M, N=grid.N, B_coil=ToroidalMagneticField(5, 1) + ) + obj = ObjectiveFunction( + FreeSurfaceError( + eq, + field, + grid=grid, + fix_I_sheet=True, + options=LaplaceOptions(solve_method="direct"), + ) + ) + obj.build(verbose=0) + + assert len(obj.things) == 1 + assert obj.things[0] is eq + assert obj.dim_x == eq.dim_x + assert np.isfinite(obj.compute_scalar(obj.x())) + @pytest.mark.unit def test_generic_with_kwargs(self): """Test GenericObjective with keyword arguments. Related to issue #1224.""" @@ -3310,6 +3422,7 @@ class TestComputeScalarResolution: CoilSetLinkingNumber, CoilSetMinDistance, CoilTorsion, + FreeSurfaceError, FusionPower, GenericObjective, HeatingPowerISS04, @@ -3483,6 +3596,44 @@ def test_compute_scalar_resolution_vacuum_boundary_error(self): f[i] = obj.compute_scalar(obj.x()) np.testing.assert_allclose(f, f[-1], rtol=5e-2) + @pytest.mark.regression + @pytest.mark.parametrize("flag", [True, False]) + def test_compute_scalar_resolution_free_surface_error(self, flag): + """FreeSurfaceError.""" + pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) + iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) + surf = FourierRZToroidalSurface( + R_lmn=[4.0, 1.0], + modes_R=[[0, 0], [1, 0]], + Z_lmn=[-1.0], + modes_Z=[[-1, 0]], + NFP=1, + ) + eq = Equilibrium(M=6, N=0, Psi=1.0, surface=surf, pressure=pres, iota=iota) + + f = np.zeros_like(self.res_array, dtype=float) + for i, res in enumerate(self.res_array): + eq.change_resolution( + L_grid=int(eq.L * res), M_grid=int(eq.M * res), N_grid=int(eq.N * res) + ) + B = ToroidalMagneticField(5, 1) + field = ( + FreeSurfaceOuterField(eq.surface, eq.M, eq.N, B_coil=B) + if flag + else SourceFreeField(eq.surface, eq.M, eq.N, B0=B) + ) + grid = LinearGrid( + rho=np.array([1.0]), + M=eq.M, + N=eq.N, + NFP=eq.NFP if eq.N > 0 else 64, + sym=False, + ) + obj = ObjectiveFunction(FreeSurfaceError(eq, field, grid=grid)) + obj.build() + f[i] = obj.compute_scalar(obj.x()) + np.testing.assert_allclose(f, f[-1], rtol=5e-2) + @pytest.mark.regression def test_compute_scalar_resolution_quadratic_flux(self): """QuadraticFlux.""" @@ -3782,7 +3933,7 @@ def test_compute_scalar_resolution_coils(self, objective): f[i] = obj.compute_scalar(obj.x()) np.testing.assert_allclose(f, f[-1], rtol=1e-2, atol=1e-12) - @pytest.mark.unit + @pytest.mark.regression def test_compute_scalar_resolution_linking_current(self): """LinkingCurrentConsistency.""" coil = FourierPlanarCoil(center=[10, 1, 0]) @@ -3831,6 +3982,7 @@ class TestObjectiveNaNGrad: CoilTorsion, EffectiveRipple, ForceBalanceAnisotropic, + FreeSurfaceError, DeflationOperator, FusionPower, GammaC, @@ -3981,6 +4133,29 @@ def test_objective_no_nangrad_boundary_error(self): g = obj.grad(obj.x(eq, ext_field)) assert not np.any(np.isnan(g)), "boundary error" + @pytest.mark.unit + @pytest.mark.parametrize("flag", [True, False]) + def test_objective_no_nangrad_free_surface_error(self, flag): + """FreeSurfaceError.""" + eq = get("W7-X") + B = ToroidalMagneticField(5, 1) + field = ( + (FreeSurfaceOuterField)(eq.surface, 2, 2, B_coil=B) + if flag + else SourceFreeField(eq.surface, 2, 2, B0=B) + ) + obj = ObjectiveFunction( + FreeSurfaceError( + eq, + field, + grid=LinearGrid(M=3, N=3, NFP=eq.NFP), + options=LaplaceOptions(solve_method="fixed_point"), + ) + ) + obj.build() + g = obj.grad(obj.x()) + assert not np.any(np.isnan(g)), "free surface error" + @pytest.mark.unit def test_objective_no_nanjac_boundary_error_kinetic_profiles(self): """Test BoundaryError with kinetic profiles. Related to GH Issue #1712."""